diff --git a/README.md b/README.md index 8f4c04641..bca81c20b 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,8 @@ https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 -### NetBird on Lawrence Systems (Video) -[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) +### Self-Host NetBird (Video) +[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ) ### Key features diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 1230a4e46..5c7cb31fc 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error { return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg { +func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) { if len(query.Question) == 0 { - return nil + return } question := query.Question[0] - logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s", - question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) + qname := strings.ToLower(question.Name) - domain := strings.ToLower(question.Name) + logger.Tracef("question: domain=%s type=%s class=%s", + qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) resp := query.SetReply(query) network := resutil.NetworkForQtype(question.Qtype) if network == "" { resp.Rcode = dns.RcodeNotImplemented - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } - mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) - // query doesn't match any configured domain + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, ".")) if mostSpecificResId == "" { resp.Rcode = dns.RcodeRefused - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype) + result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype) if result.Err != nil { - f.handleDNSError(ctx, logger, w, question, resp, domain, result) - return nil + f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime) + return } f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries) - resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...) - f.cache.set(domain, question.Qtype, result.IPs) + resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...) + f.cache.set(qname, question.Qtype, result.IPs) - return resp + f.writeResponse(logger, w, resp, qname, startTime) +} + +func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) { + if err := w.WriteMsg(resp); err != nil { + logger.Errorf("failed to write DNS response: %v", err) + return + } + + logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", + qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) +} + +// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation. +type udpResponseWriter struct { + dns.ResponseWriter + query *dns.Msg +} + +func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error { + opt := u.query.IsEdns0() + maxSize := dns.MinMsgSize + if opt != nil { + maxSize = int(opt.UDPSize()) + } + + if resp.Len() > maxSize { + resp.Truncate(maxSize) + } + + return u.ResponseWriter.WriteMsg(resp) } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { @@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - opt := query.IsEdns0() - maxSize := dns.MinMsgSize - if opt != nil { - // client advertised a larger EDNS0 buffer - maxSize = int(opt.UDPSize()) - } - - // if our response is too big, truncate and set the TC bit - if resp.Len() > maxSize { - resp.Truncate(maxSize) - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { @@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, w, query, startTime) } func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { @@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError( resp *dns.Msg, domain string, result resutil.LookupResult, + startTime time.Time, ) { qType := question.Qtype qTypeName := dns.TypeToString[qType] @@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError( // NotFound: cache negative result and respond if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess { f.cache.set(domain, question.Qtype, nil) - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError( logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...) resp.Rcode = dns.RcodeSuccess - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write cached DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError( verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType) if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess { resp.Rcode = verifyResult.Rcode - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } } @@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError( // No cache or verification failed. Log with or without the server field for more context. var dnsErr *net.DNSError if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" { - logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) + logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) } else { logger.Warnf(errResolveFailed, domain, result.Err) } - // Write final failure response. - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) } // getMatchingEntries retrieves the resource IDs for a given domain. diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 6416c2f21..7325ef8a7 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") @@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockFirewall.AssertExpectations(t) mockResolver.AssertExpectations(t) } else { - if resp != nil { - assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, - "Unauthorized domain should not return successful answers") - } + require.NotNil(t, resp, "Expected response") + assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, + "Unauthorized domain should not return successful answers") mockFirewall.AssertNotCalled(t, "UpdateSet") mockResolver.AssertNotCalled(t, "LookupNetIP") } @@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now()) // Verify response + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.NotEmpty(t, resp.Answer) - } else if resp != nil { + } else { + require.NotNil(t, resp, "Expected response") assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, "Unauthorized domain should be refused or have no answers") } @@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { query.SetQuestion("example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Verify response contains all IPs + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.Len(t, resp.Answer, 3, "Should have 3 answer records") @@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { }, } - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Check the response written to the writer require.NotNil(t, writtenResp, "Expected response to be written") @@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Second query: serve from cache after upstream failure q2 := &dns.Msg{} q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp, "expected response to be written") - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(mixedQuery+".", dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q2 := &dns.Msg{} q2.SetQuestion("EXAMPLE.COM", dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp) - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2) + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { query.SetQuestion("smtp.mail.example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) assert.Equal(t, dns.RcodeSuccess, resp.Rcode) @@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { query := &dns.Msg{} query.SetQuestion(dns.Fqdn("example.com"), tt.queryType) - var writtenResp *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writtenResp = m - return nil - }, - } + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) - - // If a response was returned, it means it should be written (happens in wrapper functions) - if resp != nil && writtenResp == nil { - writtenResp = resp - } - - require.NotNil(t, writtenResp, "Expected response to be written") - assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + resp := mockWriter.GetLastResponse() + require.NotNil(t, resp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description) if tt.expectNoAnswer { - assert.Empty(t, writtenResp.Answer, "Response should have no answer records") + assert.Empty(t, resp.Answer, "Response should have no answer records") } mockResolver.AssertExpectations(t) @@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) { query := &dns.Msg{} // Don't set any question - writeCalled := false - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writeCalled = true - return nil - }, - } - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - assert.Nil(t, resp, "Should return nil for empty query") - assert.False(t, writeCalled, "Should not write response for empty query") + assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query") } diff --git a/client/internal/engine.go b/client/internal/engine.go index 597ac7c2d..4dbd5f45e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -828,6 +828,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate } func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { + started := time.Now() + defer func() { + log.Infof("sync finished in %s", time.Since(started)) + }() e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 79f68d279..c74b46d10 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -2,6 +2,7 @@ package ice import ( "context" + "fmt" "sync" "time" @@ -32,24 +33,6 @@ type ThreadSafeAgent struct { once sync.Once } -func (a *ThreadSafeAgent) Close() error { - var err error - a.once.Do(func() { - done := make(chan error, 1) - go func() { - done <- a.Agent.Close() - }() - - select { - case err = <-done: - case <-time.After(iceAgentCloseTimeout): - log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) - err = nil - } - }) - return err -} - func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() @@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c return nil, err } + if agent == nil { + return nil, fmt.Errorf("ice.NewAgent returned nil agent without error") + } + return &ThreadSafeAgent{Agent: agent}, nil } +func (a *ThreadSafeAgent) Close() error { + var err error + a.once.Do(func() { + // Defensive check to prevent nil pointer dereference + // This can happen during sleep/wake transitions or memory corruption scenarios + // github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?) + // [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c] + agent := a.Agent + if agent == nil { + log.Warnf("ICE agent is nil during close, skipping") + return + } + + done := make(chan error, 1) + go func() { + done <- agent.Close() + }() + + select { + case err = <-done: + case <-time.After(iceAgentCloseTimeout): + log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) + err = nil + } + }) + return err +} + func GenerateICECredentials() (string, string, error) { ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) if err != nil { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index b6b9d2cf4..464f57bff 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } w.log.Debugf("agent already exists, recreate the connection") w.agentDialerCancel() - if err := w.agent.Close(); err != nil { - w.log.Warnf("failed to close ICE agent: %s", err) + if w.agent != nil { + if err := w.agent.Close(); err != nil { + w.log.Warnf("failed to close ICE agent: %s", err) + } } sessionID, err := NewICESessionID() diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index f2fda84e0..8f3ff8b11 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { } if config.AdminURL == nil { - log.Infof("using default Admin URL %s", DefaultManagementURL) + log.Infof("using default Admin URL %s", DefaultAdminURL) config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL) if err != nil { return false, err diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 2baa0e668..077b9521b 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { } func (m *DefaultManager) setupRefCounters(useNoop bool) { + var once sync.Once + var wgIface *net.Interface + toInterface := func() *net.Interface { + once.Do(func() { + wgIface = m.wgInterface.ToInterface() + }) + return wgIface + } + m.routeRefCounter = refcounter.New( func(prefix netip.Prefix, _ struct{}) (struct{}, error) { - return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) + return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface()) }, func(prefix netip.Prefix, _ struct{}) error { - return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) + return m.sysOps.RemoveVPNRoute(prefix, toInterface()) }, ) diff --git a/client/internal/routemanager/systemops/routeflags_bsd.go b/client/internal/routemanager/systemops/routeflags_bsd.go index ad32e5029..33280bfb3 100644 --- a/client/internal/routemanager/systemops/routeflags_bsd.go +++ b/client/internal/routemanager/systemops/routeflags_bsd.go @@ -4,16 +4,17 @@ package systemops import ( "strings" - "syscall" + + "golang.org/x/sys/unix" ) // filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { - if routeMessageFlags&syscall.RTF_UP == 0 { + if routeMessageFlags&unix.RTF_UP == 0 { return true } - if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 { + if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 { return true } @@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool { func formatBSDFlags(flags int) string { var flagStrs []string - if flags&syscall.RTF_UP != 0 { + if flags&unix.RTF_UP != 0 { flagStrs = append(flagStrs, "U") } - if flags&syscall.RTF_GATEWAY != 0 { + if flags&unix.RTF_GATEWAY != 0 { flagStrs = append(flagStrs, "G") } - if flags&syscall.RTF_HOST != 0 { + if flags&unix.RTF_HOST != 0 { flagStrs = append(flagStrs, "H") } - if flags&syscall.RTF_REJECT != 0 { + if flags&unix.RTF_REJECT != 0 { flagStrs = append(flagStrs, "R") } - if flags&syscall.RTF_DYNAMIC != 0 { + if flags&unix.RTF_DYNAMIC != 0 { flagStrs = append(flagStrs, "D") } - if flags&syscall.RTF_MODIFIED != 0 { + if flags&unix.RTF_MODIFIED != 0 { flagStrs = append(flagStrs, "M") } - if flags&syscall.RTF_STATIC != 0 { + if flags&unix.RTF_STATIC != 0 { flagStrs = append(flagStrs, "S") } - if flags&syscall.RTF_LLINFO != 0 { + if flags&unix.RTF_LLINFO != 0 { flagStrs = append(flagStrs, "L") } - if flags&syscall.RTF_LOCAL != 0 { + if flags&unix.RTF_LOCAL != 0 { flagStrs = append(flagStrs, "l") } - if flags&syscall.RTF_BLACKHOLE != 0 { + if flags&unix.RTF_BLACKHOLE != 0 { flagStrs = append(flagStrs, "B") } - if flags&syscall.RTF_CLONING != 0 { + if flags&unix.RTF_CLONING != 0 { flagStrs = append(flagStrs, "C") } - if flags&syscall.RTF_WASCLONED != 0 { + if flags&unix.RTF_WASCLONED != 0 { flagStrs = append(flagStrs, "W") } + if flags&unix.RTF_PROTO1 != 0 { + flagStrs = append(flagStrs, "1") + } + if flags&unix.RTF_PROTO2 != 0 { + flagStrs = append(flagStrs, "2") + } + if flags&unix.RTF_PROTO3 != 0 { + flagStrs = append(flagStrs, "3") + } if len(flagStrs) == 0 { return "-" diff --git a/client/internal/routemanager/systemops/routeflags_freebsd.go b/client/internal/routemanager/systemops/routeflags_freebsd.go index 2338fe5d8..a8c82b3ed 100644 --- a/client/internal/routemanager/systemops/routeflags_freebsd.go +++ b/client/internal/routemanager/systemops/routeflags_freebsd.go @@ -4,17 +4,18 @@ package systemops import ( "strings" - "syscall" + + "golang.org/x/sys/unix" ) // filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { - if routeMessageFlags&syscall.RTF_UP == 0 { + if routeMessageFlags&unix.RTF_UP == 0 { return true } - // NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 - if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 { + // NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0 + if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 { return true } @@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool { func formatBSDFlags(flags int) string { var flagStrs []string - if flags&syscall.RTF_UP != 0 { + if flags&unix.RTF_UP != 0 { flagStrs = append(flagStrs, "U") } - if flags&syscall.RTF_GATEWAY != 0 { + if flags&unix.RTF_GATEWAY != 0 { flagStrs = append(flagStrs, "G") } - if flags&syscall.RTF_HOST != 0 { + if flags&unix.RTF_HOST != 0 { flagStrs = append(flagStrs, "H") } - if flags&syscall.RTF_REJECT != 0 { + if flags&unix.RTF_REJECT != 0 { flagStrs = append(flagStrs, "R") } - if flags&syscall.RTF_DYNAMIC != 0 { + if flags&unix.RTF_DYNAMIC != 0 { flagStrs = append(flagStrs, "D") } - if flags&syscall.RTF_MODIFIED != 0 { + if flags&unix.RTF_MODIFIED != 0 { flagStrs = append(flagStrs, "M") } - if flags&syscall.RTF_STATIC != 0 { + if flags&unix.RTF_STATIC != 0 { flagStrs = append(flagStrs, "S") } - if flags&syscall.RTF_LLINFO != 0 { + if flags&unix.RTF_LLINFO != 0 { flagStrs = append(flagStrs, "L") } - if flags&syscall.RTF_LOCAL != 0 { + if flags&unix.RTF_LOCAL != 0 { flagStrs = append(flagStrs, "l") } - if flags&syscall.RTF_BLACKHOLE != 0 { + if flags&unix.RTF_BLACKHOLE != 0 { flagStrs = append(flagStrs, "B") } // Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0 + if flags&unix.RTF_PROTO1 != 0 { + flagStrs = append(flagStrs, "1") + } + if flags&unix.RTF_PROTO2 != 0 { + flagStrs = append(flagStrs, "2") + } + if flags&unix.RTF_PROTO3 != 0 { + flagStrs = append(flagStrs, "3") + } if len(flagStrs) == 0 { return "-" diff --git a/idp/dex/connector.go b/idp/dex/connector.go index cad682141..ba2bb1f00 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error { return nil } +// HasNonLocalConnectors checks if there are any connectors other than the local connector. +func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) { + connectors, err := p.storage.ListConnectors(ctx) + if err != nil { + return false, fmt.Errorf("failed to list connectors: %w", err) + } + + p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors)) + for _, conn := range connectors { + p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name) + if conn.ID != "local" || conn.Type != "local" { + p.logger.Info("found non-local connector", "id", conn.ID) + return true, nil + } + } + p.logger.Info("no non-local connectors found") + return false, nil +} + +// DisableLocalAuth removes the local (password) connector. +// Returns an error if no other connectors are configured. +func (p *Provider) DisableLocalAuth(ctx context.Context) error { + hasOthers, err := p.HasNonLocalConnectors(ctx) + if err != nil { + return err + } + if !hasOthers { + return fmt.Errorf("cannot disable local authentication: no other identity providers configured") + } + + // Check if local connector exists + _, err = p.storage.GetConnector(ctx, "local") + if errors.Is(err, storage.ErrNotFound) { + // Already disabled + return nil + } + if err != nil { + return fmt.Errorf("failed to check local connector: %w", err) + } + + // Delete the local connector + if err := p.storage.DeleteConnector(ctx, "local"); err != nil { + return fmt.Errorf("failed to delete local connector: %w", err) + } + + p.logger.Info("local authentication disabled") + return nil +} + +// EnableLocalAuth creates the local (password) connector if it doesn't exist. +func (p *Provider) EnableLocalAuth(ctx context.Context) error { + return ensureLocalConnector(ctx, p.storage) +} + // ensureStaticConnectors creates or updates static connectors in storage func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error { for _, conn := range connectors { diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 01310a554..8a26568a1 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -250,7 +250,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) c.metrics.CountToSyncResponseDuration(time.Since(start)) - c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) }(peer) } @@ -374,7 +377,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) return nil } @@ -783,6 +789,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI }, }, }, + MessageType: network_map.MessageTypeNetworkMap, }) c.peersUpdateManager.CloseChannel(ctx, peerID) diff --git a/management/internals/controllers/network_map/update_channel/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go index afc1e2c32..c73baf81f 100644 --- a/management/internals/controllers/network_map/update_channel/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 0, + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 0, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") @@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 10, + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 10, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } peersUpdater.SendUpdate(context.Background(), peer, update2) timeout := time.After(5 * time.Second) diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go index 33643bcbd..0ffddf8b2 100644 --- a/management/internals/controllers/network_map/update_message.go +++ b/management/internals/controllers/network_map/update_message.go @@ -4,6 +4,19 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// MessageType indicates the type of update message for debouncing strategy +type MessageType int + +const ( + // MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall) + // These updates can be safely debounced - only the latest state matters + MessageTypeNetworkMap MessageType = iota + // MessageTypeControlConfig represents control/config updates (tokens, peer expiration) + // These updates should not be dropped as they contain time-sensitive information + MessageTypeControlConfig +) + type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + MessageType MessageType } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index e413439a0..0755c50cb 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -72,7 +72,14 @@ func (s *BaseServer) UsersManager() users.Manager { func (s *BaseServer) SettingsManager() settings.Manager { return Create(s, func() settings.Manager { extraSettingsManager := integrations.NewManager(s.EventStore()) - return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager()) + + idpConfig := settings.IdpConfig{} + if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { + idpConfig.EmbeddedIdpEnabled = true + idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled + } + + return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig) }) } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 219baaf6d..ff9d7ea05 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -77,8 +77,9 @@ type Server struct { oAuthConfigProvider idp.OAuthConfigProvider - syncSem atomic.Int32 - syncLim int32 + syncSem atomic.Int32 + syncLimEnabled bool + syncLim int32 } // NewServer creates a new Management server @@ -108,6 +109,7 @@ func NewServer( blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" syncLim := int32(defaultSyncLim) + syncLimEnabled := true if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { syncLimParsed, err := strconv.Atoi(syncLimStr) if err != nil { @@ -115,6 +117,9 @@ func NewServer( } else { //nolint:gosec syncLim = int32(syncLimParsed) + if syncLim < 0 { + syncLimEnabled = false + } } } @@ -134,7 +139,8 @@ func NewServer( loginFilter: newLoginFilter(), - syncLim: syncLim, + syncLim: syncLim, + syncLimEnabled: syncLimEnabled, }, nil } @@ -212,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { - if s.syncSem.Load() >= s.syncLim { + if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim { return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") } s.syncSem.Add(1) @@ -294,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S metahash := metaHash(peerMeta, realIP.String()) s.loginFilter.addLogin(peerKey.String(), metahash) - peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) @@ -305,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -313,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -330,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { @@ -398,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +// It implements a backpressure mechanism that sends the first update immediately, +// then debounces subsequent rapid updates, ensuring only the latest update is sent +// after a quiet period. +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) + + // Create a debouncer for this peer connection + debouncer := NewUpdateDebouncer(1000 * time.Millisecond) + defer debouncer.Stop() + for { select { // condition when there are some updates + // todo set the updates channel size to 1 case update, open := <-updates: if s.appMetrics != nil { s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) @@ -410,20 +425,38 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg if !open { log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return nil } + log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { - log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) - return err + if debouncer.ProcessUpdate(update) { + // Send immediately (first update or after quiet period) + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) + return err + } + } + + // Timer expired - quiet period reached, send pending updates if any + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) == 0 { + continue + } + log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String()) + for _, pendingUpdate := range pendingUpdates { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) + return err + } } // condition when client <-> server connection has been terminated case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return srv.Context().Err() } } @@ -431,16 +464,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { key, err := s.secretsManager.GetWGKey() if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.Send(&proto.EncryptedMessage{ @@ -448,7 +481,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed sending update message") } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) @@ -480,11 +513,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even return nil } -func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() - err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) +} + +func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { + err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime) if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index ccb32202f..65e58ad41 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/internals/shared/grpc/update_debouncer.go b/management/internals/shared/grpc/update_debouncer.go new file mode 100644 index 000000000..8af9c2656 --- /dev/null +++ b/management/internals/shared/grpc/update_debouncer.go @@ -0,0 +1,103 @@ +package grpc + +import ( + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" +) + +// UpdateDebouncer implements a backpressure mechanism that: +// - Sends the first update immediately +// - Coalesces rapid subsequent network map updates (only latest matters) +// - Queues control/config updates (all must be delivered) +// - Preserves the order of messages (important for control configs between network maps) +// - Ensures pending updates are sent after a quiet period +type UpdateDebouncer struct { + debounceInterval time.Duration + timer *time.Timer + pendingUpdates []*network_map.UpdateMessage // Queue that preserves order + timerC <-chan time.Time +} + +// NewUpdateDebouncer creates a new debouncer with the specified interval +func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer { + return &UpdateDebouncer{ + debounceInterval: interval, + } +} + +// ProcessUpdate handles an incoming update and returns whether it should be sent immediately +func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool { + if d.timer == nil { + // No active debounce timer, signal to send immediately + // and start the debounce period + d.startTimer() + return true + } + + // Already in debounce period, accumulate this update preserving order + // Check if we should coalesce with the last pending update + if len(d.pendingUpdates) > 0 && + update.MessageType == network_map.MessageTypeNetworkMap && + d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap { + // Replace the last network map with this one (coalesce consecutive network maps) + d.pendingUpdates[len(d.pendingUpdates)-1] = update + } else { + // Append to the queue (preserves order for control configs and non-consecutive network maps) + d.pendingUpdates = append(d.pendingUpdates, update) + } + d.resetTimer() + return false +} + +// TimerChannel returns the timer channel for select statements +func (d *UpdateDebouncer) TimerChannel() <-chan time.Time { + if d.timer == nil { + return nil + } + return d.timerC +} + +// GetPendingUpdates returns and clears all pending updates after timer expiration. +// Updates are returned in the order they were received, with consecutive network maps +// already coalesced to only the latest one. +// If there were pending updates, it restarts the timer to continue debouncing. +// If there were no pending updates, it clears the timer (true quiet period). +func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage { + updates := d.pendingUpdates + d.pendingUpdates = nil + + if len(updates) > 0 { + // There were pending updates, so updates are still coming rapidly + // Restart the timer to continue debouncing mode + if d.timer != nil { + d.timer.Reset(d.debounceInterval) + } + } else { + // No pending updates means true quiet period - return to immediate mode + d.timer = nil + d.timerC = nil + } + + return updates +} + +// Stop stops the debouncer and cleans up resources +func (d *UpdateDebouncer) Stop() { + if d.timer != nil { + d.timer.Stop() + d.timer = nil + d.timerC = nil + } + d.pendingUpdates = nil +} + +func (d *UpdateDebouncer) startTimer() { + d.timer = time.NewTimer(d.debounceInterval) + d.timerC = d.timer.C +} + +func (d *UpdateDebouncer) resetTimer() { + d.timer.Stop() + d.timer.Reset(d.debounceInterval) +} diff --git a/management/internals/shared/grpc/update_debouncer_test.go b/management/internals/shared/grpc/update_debouncer_test.go new file mode 100644 index 000000000..075994a2d --- /dev/null +++ b/management/internals/shared/grpc/update_debouncer_test.go @@ -0,0 +1,587 @@ +package grpc + +import ( + "testing" + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + shouldSend := debouncer.ProcessUpdate(update) + + if !shouldSend { + t.Error("First update should be sent immediately") + } + + if debouncer.TimerChannel() == nil { + t.Error("Timer should be started after first update") + } +} + +func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update should be sent immediately + if !debouncer.ProcessUpdate(update1) { + t.Error("First update should be sent immediately") + } + + // Rapid subsequent updates should be coalesced + if debouncer.ProcessUpdate(update2) { + t.Error("Second rapid update should not be sent immediately") + } + + if debouncer.ProcessUpdate(update3) { + t.Error("Third rapid update should not be sent immediately") + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update3 { + t.Error("Should get the last update (update3)") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + debouncer.ProcessUpdate(update1) + + // Send second update within debounce period + debouncer.ProcessUpdate(update2) + + // Wait for timer + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update2 { + t.Error("Should get the last update") + } + if pendingUpdates[0] == update1 { + t.Error("Should not get the first update") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + debouncer.ProcessUpdate(update1) + + // Wait a bit, but not the full debounce period + time.Sleep(30 * time.Millisecond) + + // Send second update - should reset timer + debouncer.ProcessUpdate(update2) + + // Wait a bit more + time.Sleep(30 * time.Millisecond) + + // Send third update - should reset timer again + debouncer.ProcessUpdate(update3) + + // Now wait for the timer (should fire after last update's reset) + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update3 { + t.Error("Should get the last update (update3)") + } + // Timer should be restarted since there was a pending update + if debouncer.TimerChannel() == nil { + t.Error("Timer should be restarted after sending pending update") + } + case <-time.After(150 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update sent immediately + debouncer.ProcessUpdate(update1) + + // Second update coalesced + debouncer.ProcessUpdate(update2) + + // Wait for timer to expire + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) == 0 { + t.Fatal("Should have pending update") + } + + // After sending pending update, timer is restarted, so next update is NOT immediate + if debouncer.ProcessUpdate(update3) { + t.Error("Update after debounced send should not be sent immediately (timer restarted)") + } + + // Wait for the restarted timer and verify update3 is pending + select { + case <-debouncer.TimerChannel(): + finalUpdates := debouncer.GetPendingUpdates() + if len(finalUpdates) != 1 || finalUpdates[0] != update3 { + t.Error("Should get update3 as pending") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired for restarted timer") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_StopCleansUp(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send update to start timer + debouncer.ProcessUpdate(update) + + // Stop should clean up + debouncer.Stop() + + // Multiple stops should be safe + debouncer.Stop() +} + +func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate high-frequency updates + var lastUpdate *network_map.UpdateMessage + sentImmediately := 0 + for i := 0; i < 100; i++ { + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + lastUpdate = update + if debouncer.ProcessUpdate(update) { + sentImmediately++ + } + time.Sleep(1 * time.Millisecond) // Very rapid updates + } + + // Only first update should be sent immediately + if sentImmediately != 1 { + t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately) + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != lastUpdate { + t.Error("Should get the very last update") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 99 { + t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + if !debouncer.ProcessUpdate(update) { + t.Error("First update should be sent immediately") + } + + // Wait for timer to expire with no additional updates (true quiet period) + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 0 { + t.Error("Should have no pending updates") + } + // After true quiet period, timer should be cleared + if debouncer.TimerChannel() != nil { + t.Error("Timer should be cleared after quiet period") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + updates := make([]*network_map.UpdateMessage, 5) + for i := range updates { + updates[i] = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + } + + // First update sent immediately + debouncer.ProcessUpdate(updates[0]) + + // Send updates 1, 2, 3, 4 rapidly - only last one should remain pending + debouncer.ProcessUpdate(updates[1]) + debouncer.ProcessUpdate(updates[2]) + debouncer.ProcessUpdate(updates[3]) + debouncer.ProcessUpdate(updates[4]) + + // Wait for debounce + <-debouncer.TimerChannel() + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0].Update.NetworkMap.Serial != 4 { + t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial) + } +} + +func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update sent immediately + if !debouncer.ProcessUpdate(update1) { + t.Error("First update should be sent immediately") + } + + // Wait for timer without sending any more updates (true quiet period) + <-debouncer.TimerChannel() + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) != 0 { + t.Error("Should have no pending updates during quiet period") + } + + // After true quiet period, next update should be sent immediately + if !debouncer.ProcessUpdate(update2) { + t.Error("Update after true quiet period should be sent immediately") + } +} + +func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate continuous high-frequency updates + for i := 0; i < 10; i++ { + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + + if i == 0 { + // First one sent immediately + if !debouncer.ProcessUpdate(update) { + t.Error("First update should be sent immediately") + } + } else { + // All others should be coalesced (not sent immediately) + if debouncer.ProcessUpdate(update) { + t.Errorf("Update %d should not be sent immediately", i) + } + } + + // Wait a bit but send next update before debounce expires + time.Sleep(20 * time.Millisecond) + } + + // Now wait for final debounce + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) == 0 { + t.Fatal("Should have the last update pending") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 9 { + t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + netmapUpdate := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}}, + MessageType: network_map.MessageTypeNetworkMap, + } + tokenUpdate1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + tokenUpdate2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + + // First update sent immediately + debouncer.ProcessUpdate(netmapUpdate) + + // Send multiple control config updates - they should all be queued + debouncer.ProcessUpdate(tokenUpdate1) + debouncer.ProcessUpdate(tokenUpdate2) + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get both control config updates + if len(pendingUpdates) != 2 { + t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates)) + } + // Control configs should come first + if pendingUpdates[0] != tokenUpdate1 { + t.Error("First pending update should be tokenUpdate1") + } + if pendingUpdates[1] != tokenUpdate2 { + t.Error("Second pending update should be tokenUpdate2") + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + netmapUpdate1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}}, + MessageType: network_map.MessageTypeNetworkMap, + } + netmapUpdate2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}}, + MessageType: network_map.MessageTypeNetworkMap, + } + tokenUpdate := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + + // First update sent immediately + debouncer.ProcessUpdate(netmapUpdate1) + + // Send token update and network map update + debouncer.ProcessUpdate(tokenUpdate) + debouncer.ProcessUpdate(netmapUpdate2) + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get 2 updates in order: token, then network map + if len(pendingUpdates) != 2 { + t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates)) + } + // Token update should come first (preserves order) + if pendingUpdates[0] != tokenUpdate { + t.Error("First pending update should be tokenUpdate") + } + // Network map update should come second + if pendingUpdates[1] != netmapUpdate2 { + t.Error("Second pending update should be netmapUpdate2") + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_OrderPreservation(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate: 50 network maps -> 1 control config -> 50 network maps + // Expected result: 3 messages (netmap, controlConfig, netmap) + + // Send first network map immediately + firstNetmap := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}}, + MessageType: network_map.MessageTypeNetworkMap, + } + if !debouncer.ProcessUpdate(firstNetmap) { + t.Error("First update should be sent immediately") + } + + // Send 49 more network maps (will be coalesced to last one) + var lastNetmapBatch1 *network_map.UpdateMessage + for i := 1; i < 50; i++ { + lastNetmapBatch1 = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}}, + MessageType: network_map.MessageTypeNetworkMap, + } + debouncer.ProcessUpdate(lastNetmapBatch1) + } + + // Send 1 control config + controlConfig := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + debouncer.ProcessUpdate(controlConfig) + + // Send 50 more network maps (will be coalesced to last one) + var lastNetmapBatch2 *network_map.UpdateMessage + for i := 50; i < 100; i++ { + lastNetmapBatch2 = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}}, + MessageType: network_map.MessageTypeNetworkMap, + } + debouncer.ProcessUpdate(lastNetmapBatch2) + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get exactly 3 updates: netmap, controlConfig, netmap + if len(pendingUpdates) != 3 { + t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates)) + } + // First should be the last netmap from batch 1 + if pendingUpdates[0] != lastNetmapBatch1 { + t.Error("First pending update should be last netmap from batch 1") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 49 { + t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + // Second should be the control config + if pendingUpdates[1] != controlConfig { + t.Error("Second pending update should be control config") + } + // Third should be the last netmap from batch 2 + if pendingUpdates[2] != lastNetmapBatch2 { + t.Error("Third pending update should be last netmap from batch 2") + } + if pendingUpdates[2].Update.NetworkMap.Serial != 99 { + t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} diff --git a/management/server/account.go b/management/server/account.go index 748290a31..a9f59773a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -795,6 +795,19 @@ func IsEmbeddedIdp(i idp.Manager) bool { return ok } +// IsLocalAuthDisabled checks if local (email/password) authentication is disabled. +// Returns true only when using embedded IDP with local auth disabled in config. +func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool { + if isNil(i) { + return false + } + embeddedIdp, ok := i.(*idp.EmbeddedIdPManager) + if !ok { + return false + } + return embeddedIdp.IsLocalAuthDisabled() +} + // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { @@ -1657,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -1671,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID return peer, netMap, postureChecks, dnsfwdPort, nil } -func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { - err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) +func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) + if err != nil { + log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) + return nil + } + + if peer.Status.LastSeen.After(streamStartTime) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", + peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) + return nil + } + + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 5e9bb42a2..1d25b0af7 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -58,7 +58,7 @@ type Manager interface { GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error @@ -114,8 +114,8 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) - OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 86cc69e8b..443e6344e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1961,6 +1961,82 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } } +func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peerPubKey := key.PublicKey().String() + + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: peerPubKey, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, false) + require.NoError(t, err, "unable to add peer") + + t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err, "unable to get peer") + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := time.Now().UTC() + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.False(t, peer.Status.Connected, "peer should be disconnected") + }) + + t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour) + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, + "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") + }) + + t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { + node2SyncTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) + require.NoError(t, err, "node 2 should connect peer") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime") + + node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime) + require.NoError(t, err, "stale connect should not return error") + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should still be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), + "LastSeen should NOT be overwritten by stale syncTime from blocked goroutine") + }) +} + func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -1983,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} @@ -3176,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC()) assert.NoError(b, err) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 9d8dc3fe2..5049ac25c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -140,15 +140,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("register integrations endpoints: %w", err) } - // Check if embedded IdP is enabled + // Check if embedded IdP is enabled for instance manager embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager) instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP) if err != nil { return nil, fmt.Errorf("failed to create instance manager: %w", err) } - accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router) - peers.AddEndpoints(accountManager, router, networkMapController) + accounts.AddEndpoints(accountManager, settingsManager, router) + peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager) users.AddEndpoints(accountManager, router) users.AddInvitesEndpoints(accountManager, router) users.AddPublicInvitesEndpoints(accountManager, router) diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index de778d59a..122c061ce 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -36,24 +36,22 @@ const ( // handler is a handler that handles the server.Account HTTP endpoints type handler struct { - accountManager account.Manager - settingsManager settings.Manager - embeddedIdpEnabled bool + accountManager account.Manager + settingsManager settings.Manager } -func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) { - accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled) +func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) { + accountsHandler := newHandler(accountManager, settingsManager) router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler -func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler { +func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler { return &handler{ - accountManager: accountManager, - settingsManager: settingsManager, - embeddedIdpEnabled: embeddedIdpEnabled, + accountManager: accountManager, + settingsManager: settingsManager, } } @@ -165,7 +163,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled) + resp := toAccountResponse(accountID, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -292,7 +290,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled) + resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding) util.WriteJSONObject(r.Context(), w, &resp) } @@ -321,7 +319,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -341,7 +339,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, - EmbeddedIdpEnabled: &embeddedIdpEnabled, + EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, + LocalAuthDisabled: &settings.LocalAuthDisabled, } if settings.NetworkRange.IsValid() { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index e455372c8..6cbd5908d 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -33,7 +33,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { AnyTimes() return &handler{ - embeddedIdpEnabled: false, accountManager: &mock_server.MockAccountManager{ GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil @@ -124,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -148,6 +148,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -172,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr("latest"), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -196,6 +198,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -220,6 +223,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -244,6 +248,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index 5d8baaf8d..cd9fae6b8 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -46,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w) return } - + log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired) util.WriteJSONObject(r.Context(), w, api.InstanceStatus{ SetupRequired: setupRequired, }) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 3abb1b6f2..2295470ec 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -17,6 +17,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -26,11 +27,12 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { accountManager account.Manager + permissionsManager permissions.Manager networkMapController network_map.Controller } -func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) { - peersHandler := NewHandler(accountManager, networkMapController) +func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) { + peersHandler := NewHandler(accountManager, networkMapController, permissionsManager) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap } // NewHandler creates a new peers Handler -func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler { +func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler { return &Handler{ accountManager: accountManager, networkMapController: networkMapController, + permissionsManager: permissionsManager, } } @@ -359,13 +362,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) + user, err := h.accountManager.GetUserByID(r.Context(), userID) if err != nil { util.WriteError(r.Context(), err, w) return } - user, err := h.accountManager.GetUserByID(r.Context(), userID) + err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false) + if err != nil { + util.WriteError(r.Context(), status.NewPermissionDeniedError(), w) + return + } + + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 869a39b5e..786c144fc 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -13,13 +13,15 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" - "go.uber.org/mock/gomock" + ugomock "go.uber.org/mock/gomock" "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" @@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { }, } - ctrl := gomock.NewController(t) + ctrl := ugomock.NewController(t) networkMapController := network_map.NewMockController(ctrl) networkMapController.EXPECT(). @@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { Return("domain"). AnyTimes() + ctrl2 := gomock.NewController(t) + permissionsManager := permissions.NewMockManager(ctrl2) + permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { }, }, networkMapController: networkMapController, + permissionsManager: permissionsManager, } } diff --git a/management/server/http/handlers/users/invites_handler_test.go b/management/server/http/handlers/users/invites_handler_test.go index 80826b9d4..529ea24d6 100644 --- a/management/server/http/handlers/users/invites_handler_test.go +++ b/management/server/http/handlers/users/invites_handler_test.go @@ -205,6 +205,14 @@ func TestCreateInvite(t *testing.T) { return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") }, }, + { + name: "local auth disabled", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + }, + }, { name: "invalid JSON", requestBody: `{invalid json}`, @@ -376,6 +384,15 @@ func TestAcceptInvite(t *testing.T) { return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") }, }, + { + name: "local auth disabled", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + }, + }, { name: "missing token", token: "", diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index bdc68e85f..73d9310d4 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee proxyController := integrations.NewController(store) userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{}) peersManager := peers.NewManager(store, permissionsManager) jobManager := job.NewJobManager(nil, store, peersManager) diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index e03ef0c27..b2ab94d69 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -43,6 +43,11 @@ type EmbeddedIdPConfig struct { Owner *OwnerConfig // SignKeyRefreshEnabled enables automatic key rotation for signing keys SignKeyRefreshEnabled bool + // LocalAuthDisabled disables the local (email/password) authentication connector. + // When true, users cannot authenticate via email/password, only via external identity providers. + // Existing local users are preserved and will be able to login again if re-enabled. + // Cannot be enabled if no external identity provider connectors are configured. + LocalAuthDisabled bool } // EmbeddedStorageConfig holds storage configuration for the embedded IdP. @@ -110,6 +115,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { Issuer: "NetBird", Theme: "light", }, + // Always enable password DB initially - we disable the local connector after startup if needed. + // This ensures Dex has at least one connector during initialization. EnablePasswordDB: true, StaticClients: []storage.Client{ { @@ -197,11 +204,32 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe return nil, err } + log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config) + provider, err := dex.NewProviderFromYAML(ctx, yamlConfig) if err != nil { return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err) } + // If local auth is disabled, validate that other connectors exist + if config.LocalAuthDisabled { + hasOthers, err := provider.HasNonLocalConnectors(ctx) + if err != nil { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("failed to check connectors: %w", err) + } + if !hasOthers { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured") + } + // Ensure local connector is removed (it might exist from a previous run) + if err := provider.DisableLocalAuth(ctx); err != nil { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("failed to disable local auth: %w", err) + } + log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used") + } + log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer) return &EmbeddedIdPManager{ @@ -286,6 +314,8 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]* return nil, fmt.Errorf("failed to list users: %w", err) } + log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users)) + indexedUsers := make(map[string][]*UserData) for _, user := range users { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{ @@ -295,11 +325,17 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]* }) } + log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID])) + return indexedUsers, nil } // CreateUser creates a new user in the embedded IdP. func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + if m.config.LocalAuthDisabled { + return nil, fmt.Errorf("local user creation is disabled") + } + if m.appMetrics != nil { m.appMetrics.IDPMetrics().CountCreateUser() } @@ -369,6 +405,10 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) ( // Unlike CreateUser which auto-generates a password, this method uses the provided password. // This is useful for instance setup where the user provides their own password. func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) { + if m.config.LocalAuthDisabled { + return nil, fmt.Errorf("local user creation is disabled") + } + if m.appMetrics != nil { m.appMetrics.IDPMetrics().CountCreateUser() } @@ -558,3 +598,13 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string { func (m *EmbeddedIdPManager) GetUserIDClaim() string { return defaultUserIDClaim } + +// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration. +func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool { + return m.config.LocalAuthDisabled +} + +// HasNonLocalConnectors checks if there are any identity provider connectors other than local. +func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) { + return m.provider.HasNonLocalConnectors(ctx) +} diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go index d8d3009dd..4dda483fb 100644 --- a/management/server/idp/embedded_test.go +++ b/management/server/idp/embedded_test.go @@ -370,3 +370,234 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) { }) } } + +func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) { + ctx := context.Background() + + t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + _, err = NewEmbeddedIdPManager(ctx, config, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no other identity providers configured") + }) + + t.Run("local auth enabled by default", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Verify local auth is enabled by default + assert.False(t, manager.IsLocalAuthDisabled()) + }) + + t.Run("start with local auth disabled when connector exists", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager with local auth enabled and add a connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + // Create a user + userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com") + require.NoError(t, err) + userID := userData.ID + + // Add an external connector (Google doesn't require OIDC discovery) + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + // Stop the first manager + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Now create a new manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Verify local auth is disabled via config + assert.True(t, manager2.IsLocalAuthDisabled()) + + // Verify the user still exists in storage (just can't login via local) + lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{}) + require.NoError(t, err) + assert.Equal(t, "preserved@example.com", lookedUp.Email) + }) + + t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager and add an external connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Create manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Try to create a user - should fail + _, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "local user creation is disabled") + }) + + t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager and add an external connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Create manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Try to create a user with password - should fail + _, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User") + require.Error(t, err) + assert.Contains(t, err.Error(), "local user creation is disabled") + }) +} diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 6a0509ebd..19e3abdc0 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -104,13 +104,22 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) } func (m *DefaultManager) loadSetupRequired(ctx context.Context) error { + // Check if there are any accounts in the NetBird store + numAccounts, err := m.store.GetAccountsCounter(ctx) + if err != nil { + return err + } + hasAccounts := numAccounts > 0 + + // Check if there are any users in the embedded IdP (Dex) users, err := m.embeddedIdpManager.GetAllAccounts(ctx) if err != nil { return err } + hasLocalUsers := len(users) > 0 m.setupMu.Lock() - m.setupRequired = len(users) == 0 + m.setupRequired = !(hasAccounts || hasLocalUsers) m.setupMu.Unlock() return nil diff --git a/management/server/management_test.go b/management/server/management_test.go index 0864baadf..de02855bf 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) { initialPeers := 10 additionalPeers := 10 + expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself var peers []wgtypes.Key for i := 0; i < initialPeers; i++ { @@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) { peers = append(peers, key) } + // Track the maximum peer count each peer has seen + type peerState struct { + mu sync.Mutex + maxPeerCount int + done bool + } + peerStates := make(map[string]*peerState) + for _, pk := range peers { + peerStates[pk.PublicKey().String()] = &peerState{} + } + var wg sync.WaitGroup - wg.Add(initialPeers + initialPeers*additionalPeers) + wg.Add(initialPeers) // One completion per initial peer var syncClients []mgmtProto.ManagementService_SyncClient for _, pk := range peers { @@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) { syncClients = append(syncClients, s) go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) { + pubKey := pk.PublicKey().String() + state := peerStates[pubKey] + for { encMsg := &mgmtProto.EncryptedMessage{} err := syncStream.RecvMsg(encMsg) @@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) { } decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk) if decErr != nil { - t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr) + t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr) return } resp := &mgmtProto.SyncResponse{} umErr := pb.Unmarshal(decryptedBytes, resp) if umErr != nil { - t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr) + t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr) return } - // We only count if there's a new peer update - if len(resp.GetRemotePeers()) > 0 { + + // Track the maximum peer count seen (due to debouncing, updates are coalesced) + peerCount := len(resp.GetRemotePeers()) + state.mu.Lock() + if peerCount > state.maxPeerCount { + state.maxPeerCount = peerCount + } + // Signal completion when this peer has seen all expected peers + if !state.done && state.maxPeerCount >= expectedPeerCount { + state.done = true wg.Done() } + state.mu.Unlock() } }(pk, s) } @@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) { time.Sleep(time.Duration(n) * time.Millisecond) } - wg.Wait() + // Wait for debouncer to flush final updates (debounce interval is 1000ms) + time.Sleep(1500 * time.Millisecond) + + // Wait with timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success - all peers received expected peer count + case <-time.After(5 * time.Second): + // Timeout - report which peers didn't receive all updates + t.Error("Timeout waiting for all peers to receive updates") + for pubKey, state := range peerStates { + state.mu.Lock() + if state.maxPeerCount < expectedPeerCount { + t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount) + } + state.mu.Unlock() + } + } for _, sc := range syncClients { err := sc.CloseSend() diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 026989898..8471d0a94 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -37,8 +37,8 @@ type MockAccountManager struct { GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -214,16 +214,15 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime) } return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { - // TODO implement me - panic("implement me") +func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + return nil } func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { @@ -323,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) + return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index 71d3fbb0e..57afd8207 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { +// syncTime is used as the LastSeen timestamp and for stale request detection +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { var peer *nbpeer.Peer var settings *types.Settings var expired bool var err error + var skipped bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) @@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } - expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID) + if connected && !syncTime.After(peer.Status.LastSeen) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", + peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) + skipped = true + return nil + } + + expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime) return err }) + if skipped { + return nil + } if err != nil { return err } @@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus - newStatus.LastSeen = time.Now().UTC() + newStatus.LastSeen = syncTime newStatus.Connected = connected // whenever peer got connected that means that it logged in successfully if newStatus.Connected { diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 2b2896572..74af0a3ef 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -24,19 +24,28 @@ type Manager interface { UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error) } +// IdpConfig holds IdP-related configuration that is set at runtime +// and not stored in the database. +type IdpConfig struct { + EmbeddedIdpEnabled bool + LocalAuthDisabled bool +} + type managerImpl struct { store store.Store extraSettingsManager extra_settings.Manager userManager users.Manager permissionsManager permissions.Manager + idpConfig IdpConfig } -func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager { +func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager { return &managerImpl{ store: store, extraSettingsManager: extraSettingsManager, userManager: userManager, permissionsManager: permissionsManager, + idpConfig: idpConfig, } } @@ -74,6 +83,10 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled } + // Fill in IdP-related runtime settings + settings.EmbeddedIdpEnabled = m.idpConfig.EmbeddedIdpEnabled + settings.LocalAuthDisabled = m.idpConfig.LocalAuthDisabled + return settings, nil } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 867e12bef..a94e01b78 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -55,6 +55,14 @@ type Settings struct { // AutoUpdateVersion client auto-update version AutoUpdateVersion string `gorm:"default:'disabled'"` + + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. + // This is a runtime-only field, not stored in the database. + EmbeddedIdpEnabled bool `gorm:"-"` + + // LocalAuthDisabled indicates if local (email/password) authentication is disabled. + // This is a runtime-only field, not stored in the database. + LocalAuthDisabled bool `gorm:"-"` } // Copy copies the Settings struct @@ -76,6 +84,8 @@ func (s *Settings) Copy() *Settings { DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, AutoUpdateVersion: s.AutoUpdateVersion, + EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, + LocalAuthDisabled: s.LocalAuthDisabled, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/user.go b/management/server/user.go index 51da7a633..48005f325 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -191,6 +191,10 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID // Unlike createNewIdpUser, this method fetches user data directly from the database // since the embedded IdP usage ensures the username and email are stored locally in the User table. func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) { + if IsLocalAuthDisabled(ctx, am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID) if err != nil { return nil, fmt.Errorf("failed to get inviter user: %w", err) @@ -1462,6 +1466,10 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } + if IsLocalAuthDisabled(ctx, am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + if err := validateUserInvite(invite); err != nil { return nil, err } @@ -1621,6 +1629,10 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } + if IsLocalAuthDisabled(ctx, am.idpManager) { + return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + if password == "" { return status.Errorf(status.InvalidArgument, "password is required") } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index be1cf9b76..7d3a95a5c 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -296,6 +296,11 @@ components: type: boolean readOnly: true example: false + local_auth_disabled: + description: Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field. + type: boolean + readOnly: true + example: false required: - peer_login_expiration_enabled - peer_login_expiration diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 89a61ad74..7dec2bb7a 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -446,6 +446,9 @@ type AccountSettings struct { // LazyConnectionEnabled Enables or disables experimental lazy connection LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"` + // LocalAuthDisabled Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field. + LocalAuthDisabled *bool `json:"local_auth_disabled,omitempty"` + // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"`