mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Merge branch 'main' into prototype/reverse-proxy
This commit is contained in:
@@ -60,8 +60,8 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### Self-Host NetBird (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://youtu.be/bZAgpT6nzaQ)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
|
|||||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
qname := strings.ToLower(question.Name)
|
||||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
|
||||||
|
|
||||||
domain := strings.ToLower(question.Name)
|
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||||
// query doesn't match any configured domain
|
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||||
f.cache.set(domain, question.Qtype, result.IPs)
|
f.cache.set(qname, question.Qtype, result.IPs)
|
||||||
|
|
||||||
return resp
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||||
|
type udpResponseWriter struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
query *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||||
|
opt := u.query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.ResponseWriter.WriteMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
opt := query.IsEdns0()
|
|
||||||
maxSize := dns.MinMsgSize
|
|
||||||
if opt != nil {
|
|
||||||
// client advertised a larger EDNS0 buffer
|
|
||||||
maxSize = int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
// if our response is too big, truncate and set the TC bit
|
|
||||||
if resp.Len() > maxSize {
|
|
||||||
resp.Truncate(maxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, w, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// NotFound: cache negative result and respond
|
// NotFound: cache negative result and respond
|
||||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
resp.Rcode = verifyResult.Rcode
|
resp.Rcode = verifyResult.Rcode
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// No cache or verification failed. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write final failure response.
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||||
|
|||||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockFirewall.AssertExpectations(t)
|
mockFirewall.AssertExpectations(t)
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
} else {
|
} else {
|
||||||
if resp != nil {
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
"Unauthorized domain should not return successful answers")
|
"Unauthorized domain should not return successful answers")
|
||||||
}
|
|
||||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
}
|
}
|
||||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.NotEmpty(t, resp.Answer)
|
require.NotEmpty(t, resp.Answer)
|
||||||
} else if resp != nil {
|
} else {
|
||||||
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
"Unauthorized domain should be refused or have no answers")
|
"Unauthorized domain should be refused or have no answers")
|
||||||
}
|
}
|
||||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Second query: serve from cache after upstream failure
|
// Second query: serve from cache after upstream failure
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "expected response to be written")
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2, "expected response to be written")
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
|
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp)
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2)
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
var writtenResp *dns.Msg
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writtenResp = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
resp := mockWriter.GetLastResponse()
|
||||||
|
require.NotNil(t, resp, "Expected response to be written")
|
||||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||||
if resp != nil && writtenResp == nil {
|
|
||||||
writtenResp = resp
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
|
||||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
|
||||||
|
|
||||||
if tt.expectNoAnswer {
|
if tt.expectNoAnswer {
|
||||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||||
}
|
}
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|
||||||
writeCalled := false
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writeCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
|
||||||
|
|
||||||
assert.Nil(t, resp, "Should return nil for empty query")
|
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -828,6 +828,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.Infof("sync finished in %s", time.Since(started))
|
||||||
|
}()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ice
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
|
|||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ThreadSafeAgent) Close() error {
|
|
||||||
var err error
|
|
||||||
a.once.Do(func() {
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- a.Agent.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-done:
|
|
||||||
case <-time.After(iceAgentCloseTimeout):
|
|
||||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if agent == nil {
|
||||||
|
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||||
|
}
|
||||||
|
|
||||||
return &ThreadSafeAgent{Agent: agent}, nil
|
return &ThreadSafeAgent{Agent: agent}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
|
var err error
|
||||||
|
a.once.Do(func() {
|
||||||
|
// Defensive check to prevent nil pointer dereference
|
||||||
|
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||||
|
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||||
|
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||||
|
agent := a.Agent
|
||||||
|
if agent == nil {
|
||||||
|
log.Warnf("ICE agent is nil during close, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateICECredentials() (string, string, error) {
|
func GenerateICECredentials() (string, string, error) {
|
||||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if err := w.agent.Close(); err != nil {
|
if w.agent != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
if err := w.agent.Close(); err != nil {
|
||||||
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := NewICESessionID()
|
sessionID, err := NewICESessionID()
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.AdminURL == nil {
|
if config.AdminURL == nil {
|
||||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
|
var once sync.Once
|
||||||
|
var wgIface *net.Interface
|
||||||
|
toInterface := func() *net.Interface {
|
||||||
|
once.Do(func() {
|
||||||
|
wgIface = m.wgInterface.ToInterface()
|
||||||
|
})
|
||||||
|
return wgIface
|
||||||
|
}
|
||||||
|
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_CLONING != 0 {
|
if flags&unix.RTF_CLONING != 0 {
|
||||||
flagStrs = append(flagStrs, "C")
|
flagStrs = append(flagStrs, "C")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_WASCLONED != 0 {
|
if flags&unix.RTF_WASCLONED != 0 {
|
||||||
flagStrs = append(flagStrs, "W")
|
flagStrs = append(flagStrs, "W")
|
||||||
}
|
}
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -4,17 +4,18 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
|
||||||
|
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||||
|
connectors, err := p.storage.ListConnectors(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to list connectors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
|
||||||
|
for _, conn := range connectors {
|
||||||
|
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
|
||||||
|
if conn.ID != "local" || conn.Type != "local" {
|
||||||
|
p.logger.Info("found non-local connector", "id", conn.ID)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.logger.Info("no non-local connectors found")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableLocalAuth removes the local (password) connector.
|
||||||
|
// Returns an error if no other connectors are configured.
|
||||||
|
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
|
||||||
|
hasOthers, err := p.HasNonLocalConnectors(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !hasOthers {
|
||||||
|
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if local connector exists
|
||||||
|
_, err = p.storage.GetConnector(ctx, "local")
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
// Already disabled
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check local connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the local connector
|
||||||
|
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete local connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("local authentication disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
|
||||||
|
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
|
||||||
|
return ensureLocalConnector(ctx, p.storage)
|
||||||
|
}
|
||||||
|
|
||||||
// ensureStaticConnectors creates or updates static connectors in storage
|
// ensureStaticConnectors creates or updates static connectors in storage
|
||||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||||
for _, conn := range connectors {
|
for _, conn := range connectors {
|
||||||
|
|||||||
@@ -250,7 +250,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
}(peer)
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,7 +377,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -783,6 +789,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) {
|
|||||||
func TestSendUpdate(t *testing.T) {
|
func TestSendUpdate(t *testing.T) {
|
||||||
peer := "test-sendupdate"
|
peer := "test-sendupdate"
|
||||||
peersUpdater := NewPeersUpdateManager(nil)
|
peersUpdater := NewPeersUpdateManager(nil)
|
||||||
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update1 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 0,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
@@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) {
|
|||||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||||
}
|
}
|
||||||
|
|
||||||
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update2 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 10,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
||||||
timeout := time.After(5 * time.Second)
|
timeout := time.After(5 * time.Second)
|
||||||
|
|||||||
@@ -4,6 +4,19 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MessageType indicates the type of update message for debouncing strategy
|
||||||
|
type MessageType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
|
||||||
|
// These updates can be safely debounced - only the latest state matters
|
||||||
|
MessageTypeNetworkMap MessageType = iota
|
||||||
|
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
|
||||||
|
// These updates should not be dropped as they contain time-sensitive information
|
||||||
|
MessageTypeControlConfig
|
||||||
|
)
|
||||||
|
|
||||||
type UpdateMessage struct {
|
type UpdateMessage struct {
|
||||||
Update *proto.SyncResponse
|
Update *proto.SyncResponse
|
||||||
|
MessageType MessageType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,14 @@ func (s *BaseServer) UsersManager() users.Manager {
|
|||||||
func (s *BaseServer) SettingsManager() settings.Manager {
|
func (s *BaseServer) SettingsManager() settings.Manager {
|
||||||
return Create(s, func() settings.Manager {
|
return Create(s, func() settings.Manager {
|
||||||
extraSettingsManager := integrations.NewManager(s.EventStore())
|
extraSettingsManager := integrations.NewManager(s.EventStore())
|
||||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
|
|
||||||
|
idpConfig := settings.IdpConfig{}
|
||||||
|
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||||
|
idpConfig.EmbeddedIdpEnabled = true
|
||||||
|
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -77,8 +77,9 @@ type Server struct {
|
|||||||
|
|
||||||
oAuthConfigProvider idp.OAuthConfigProvider
|
oAuthConfigProvider idp.OAuthConfigProvider
|
||||||
|
|
||||||
syncSem atomic.Int32
|
syncSem atomic.Int32
|
||||||
syncLim int32
|
syncLimEnabled bool
|
||||||
|
syncLim int32
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
@@ -108,6 +109,7 @@ func NewServer(
|
|||||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||||
|
|
||||||
syncLim := int32(defaultSyncLim)
|
syncLim := int32(defaultSyncLim)
|
||||||
|
syncLimEnabled := true
|
||||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -115,6 +117,9 @@ func NewServer(
|
|||||||
} else {
|
} else {
|
||||||
//nolint:gosec
|
//nolint:gosec
|
||||||
syncLim = int32(syncLimParsed)
|
syncLim = int32(syncLimParsed)
|
||||||
|
if syncLim < 0 {
|
||||||
|
syncLimEnabled = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,7 +139,8 @@ func NewServer(
|
|||||||
|
|
||||||
loginFilter: newLoginFilter(),
|
loginFilter: newLoginFilter(),
|
||||||
|
|
||||||
syncLim: syncLim,
|
syncLim: syncLim,
|
||||||
|
syncLimEnabled: syncLimEnabled,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
|
|||||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
if s.syncSem.Load() >= s.syncLim {
|
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
|
||||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||||
}
|
}
|
||||||
s.syncSem.Add(1)
|
s.syncSem.Add(1)
|
||||||
@@ -294,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
metahash := metaHash(peerMeta, realIP.String())
|
metahash := metaHash(peerMeta, realIP.String())
|
||||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||||
|
|
||||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -305,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -313,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
|
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||||
@@ -398,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
// It implements a backpressure mechanism that sends the first update immediately,
|
||||||
|
// then debounces subsequent rapid updates, ensuring only the latest update is sent
|
||||||
|
// after a quiet period.
|
||||||
|
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||||
|
|
||||||
|
// Create a debouncer for this peer connection
|
||||||
|
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// condition when there are some updates
|
// condition when there are some updates
|
||||||
|
// todo set the updates channel size to 1
|
||||||
case update, open := <-updates:
|
case update, open := <-updates:
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||||
@@ -410,20 +425,38 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
if !open {
|
if !open {
|
||||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
if debouncer.ProcessUpdate(update) {
|
||||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
// Send immediately (first update or after quiet period)
|
||||||
return err
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timer expired - quiet period reached, send pending updates if any
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
|
||||||
|
for _, pendingUpdate := range pendingUpdates {
|
||||||
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// condition when client <-> server connection has been terminated
|
// condition when client <-> server connection has been terminated
|
||||||
case <-srv.Context().Done():
|
case <-srv.Context().Done():
|
||||||
// happens when connection drops, e.g. client disconnects
|
// happens when connection drops, e.g. client disconnects
|
||||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return srv.Context().Err()
|
return srv.Context().Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -431,16 +464,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||||
// then sends the encrypted message to the connected peer via the sync server.
|
// then sends the encrypted message to the connected peer via the sync server.
|
||||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
err = srv.Send(&proto.EncryptedMessage{
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
@@ -448,7 +481,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
|||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||||
@@ -480,11 +513,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
|
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||||
@@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||||
|
|||||||
103
management/internals/shared/grpc/update_debouncer.go
Normal file
103
management/internals/shared/grpc/update_debouncer.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateDebouncer implements a backpressure mechanism that:
|
||||||
|
// - Sends the first update immediately
|
||||||
|
// - Coalesces rapid subsequent network map updates (only latest matters)
|
||||||
|
// - Queues control/config updates (all must be delivered)
|
||||||
|
// - Preserves the order of messages (important for control configs between network maps)
|
||||||
|
// - Ensures pending updates are sent after a quiet period
|
||||||
|
type UpdateDebouncer struct {
|
||||||
|
debounceInterval time.Duration
|
||||||
|
timer *time.Timer
|
||||||
|
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
|
||||||
|
timerC <-chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUpdateDebouncer creates a new debouncer with the specified interval
|
||||||
|
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
|
||||||
|
return &UpdateDebouncer{
|
||||||
|
debounceInterval: interval,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
|
||||||
|
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
|
||||||
|
if d.timer == nil {
|
||||||
|
// No active debounce timer, signal to send immediately
|
||||||
|
// and start the debounce period
|
||||||
|
d.startTimer()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Already in debounce period, accumulate this update preserving order
|
||||||
|
// Check if we should coalesce with the last pending update
|
||||||
|
if len(d.pendingUpdates) > 0 &&
|
||||||
|
update.MessageType == network_map.MessageTypeNetworkMap &&
|
||||||
|
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
|
||||||
|
// Replace the last network map with this one (coalesce consecutive network maps)
|
||||||
|
d.pendingUpdates[len(d.pendingUpdates)-1] = update
|
||||||
|
} else {
|
||||||
|
// Append to the queue (preserves order for control configs and non-consecutive network maps)
|
||||||
|
d.pendingUpdates = append(d.pendingUpdates, update)
|
||||||
|
}
|
||||||
|
d.resetTimer()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimerChannel returns the timer channel for select statements
|
||||||
|
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
|
||||||
|
if d.timer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return d.timerC
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingUpdates returns and clears all pending updates after timer expiration.
|
||||||
|
// Updates are returned in the order they were received, with consecutive network maps
|
||||||
|
// already coalesced to only the latest one.
|
||||||
|
// If there were pending updates, it restarts the timer to continue debouncing.
|
||||||
|
// If there were no pending updates, it clears the timer (true quiet period).
|
||||||
|
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
|
||||||
|
updates := d.pendingUpdates
|
||||||
|
d.pendingUpdates = nil
|
||||||
|
|
||||||
|
if len(updates) > 0 {
|
||||||
|
// There were pending updates, so updates are still coming rapidly
|
||||||
|
// Restart the timer to continue debouncing mode
|
||||||
|
if d.timer != nil {
|
||||||
|
d.timer.Reset(d.debounceInterval)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No pending updates means true quiet period - return to immediate mode
|
||||||
|
d.timer = nil
|
||||||
|
d.timerC = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the debouncer and cleans up resources
|
||||||
|
func (d *UpdateDebouncer) Stop() {
|
||||||
|
if d.timer != nil {
|
||||||
|
d.timer.Stop()
|
||||||
|
d.timer = nil
|
||||||
|
d.timerC = nil
|
||||||
|
}
|
||||||
|
d.pendingUpdates = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UpdateDebouncer) startTimer() {
|
||||||
|
d.timer = time.NewTimer(d.debounceInterval)
|
||||||
|
d.timerC = d.timer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UpdateDebouncer) resetTimer() {
|
||||||
|
d.timer.Stop()
|
||||||
|
d.timer.Reset(d.debounceInterval)
|
||||||
|
}
|
||||||
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldSend := debouncer.ProcessUpdate(update)
|
||||||
|
|
||||||
|
if !shouldSend {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
if debouncer.TimerChannel() == nil {
|
||||||
|
t.Error("Timer should be started after first update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update should be sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update1) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rapid subsequent updates should be coalesced
|
||||||
|
if debouncer.ProcessUpdate(update2) {
|
||||||
|
t.Error("Second rapid update should not be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
if debouncer.ProcessUpdate(update3) {
|
||||||
|
t.Error("Third rapid update should not be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update3 {
|
||||||
|
t.Error("Should get the last update (update3)")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Send second update within debounce period
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait for timer
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update2 {
|
||||||
|
t.Error("Should get the last update")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] == update1 {
|
||||||
|
t.Error("Should not get the first update")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Wait a bit, but not the full debounce period
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send second update - should reset timer
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait a bit more
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send third update - should reset timer again
|
||||||
|
debouncer.ProcessUpdate(update3)
|
||||||
|
|
||||||
|
// Now wait for the timer (should fire after last update's reset)
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update3 {
|
||||||
|
t.Error("Should get the last update (update3)")
|
||||||
|
}
|
||||||
|
// Timer should be restarted since there was a pending update
|
||||||
|
if debouncer.TimerChannel() == nil {
|
||||||
|
t.Error("Timer should be restarted after sending pending update")
|
||||||
|
}
|
||||||
|
case <-time.After(150 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Second update coalesced
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait for timer to expire
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
t.Fatal("Should have pending update")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After sending pending update, timer is restarted, so next update is NOT immediate
|
||||||
|
if debouncer.ProcessUpdate(update3) {
|
||||||
|
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the restarted timer and verify update3 is pending
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
finalUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
|
||||||
|
t.Error("Should get update3 as pending")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired for restarted timer")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send update to start timer
|
||||||
|
debouncer.ProcessUpdate(update)
|
||||||
|
|
||||||
|
// Stop should clean up
|
||||||
|
debouncer.Stop()
|
||||||
|
|
||||||
|
// Multiple stops should be safe
|
||||||
|
debouncer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate high-frequency updates
|
||||||
|
var lastUpdate *network_map.UpdateMessage
|
||||||
|
sentImmediately := 0
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
lastUpdate = update
|
||||||
|
if debouncer.ProcessUpdate(update) {
|
||||||
|
sentImmediately++
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Millisecond) // Very rapid updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only first update should be sent immediately
|
||||||
|
if sentImmediately != 1 {
|
||||||
|
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != lastUpdate {
|
||||||
|
t.Error("Should get the very last update")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
|
||||||
|
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
if !debouncer.ProcessUpdate(update) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for timer to expire with no additional updates (true quiet period)
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 0 {
|
||||||
|
t.Error("Should have no pending updates")
|
||||||
|
}
|
||||||
|
// After true quiet period, timer should be cleared
|
||||||
|
if debouncer.TimerChannel() != nil {
|
||||||
|
t.Error("Timer should be cleared after quiet period")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
updates := make([]*network_map.UpdateMessage, 5)
|
||||||
|
for i := range updates {
|
||||||
|
updates[i] = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(updates[0])
|
||||||
|
|
||||||
|
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
|
||||||
|
debouncer.ProcessUpdate(updates[1])
|
||||||
|
debouncer.ProcessUpdate(updates[2])
|
||||||
|
debouncer.ProcessUpdate(updates[3])
|
||||||
|
debouncer.ProcessUpdate(updates[4])
|
||||||
|
|
||||||
|
// Wait for debounce
|
||||||
|
<-debouncer.TimerChannel()
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
|
||||||
|
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update1) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for timer without sending any more updates (true quiet period)
|
||||||
|
<-debouncer.TimerChannel()
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) != 0 {
|
||||||
|
t.Error("Should have no pending updates during quiet period")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After true quiet period, next update should be sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update2) {
|
||||||
|
t.Error("Update after true quiet period should be sent immediately")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate continuous high-frequency updates
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
// First one sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// All others should be coalesced (not sent immediately)
|
||||||
|
if debouncer.ProcessUpdate(update) {
|
||||||
|
t.Errorf("Update %d should not be sent immediately", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit but send next update before debounce expires
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now wait for final debounce
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
t.Fatal("Should have the last update pending")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
|
||||||
|
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
netmapUpdate := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
tokenUpdate1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
tokenUpdate2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate)
|
||||||
|
|
||||||
|
// Send multiple control config updates - they should all be queued
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate1)
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate2)
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get both control config updates
|
||||||
|
if len(pendingUpdates) != 2 {
|
||||||
|
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// Control configs should come first
|
||||||
|
if pendingUpdates[0] != tokenUpdate1 {
|
||||||
|
t.Error("First pending update should be tokenUpdate1")
|
||||||
|
}
|
||||||
|
if pendingUpdates[1] != tokenUpdate2 {
|
||||||
|
t.Error("Second pending update should be tokenUpdate2")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
netmapUpdate1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
netmapUpdate2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
tokenUpdate := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate1)
|
||||||
|
|
||||||
|
// Send token update and network map update
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate)
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate2)
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get 2 updates in order: token, then network map
|
||||||
|
if len(pendingUpdates) != 2 {
|
||||||
|
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// Token update should come first (preserves order)
|
||||||
|
if pendingUpdates[0] != tokenUpdate {
|
||||||
|
t.Error("First pending update should be tokenUpdate")
|
||||||
|
}
|
||||||
|
// Network map update should come second
|
||||||
|
if pendingUpdates[1] != netmapUpdate2 {
|
||||||
|
t.Error("Second pending update should be netmapUpdate2")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate: 50 network maps -> 1 control config -> 50 network maps
|
||||||
|
// Expected result: 3 messages (netmap, controlConfig, netmap)
|
||||||
|
|
||||||
|
// Send first network map immediately
|
||||||
|
firstNetmap := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
if !debouncer.ProcessUpdate(firstNetmap) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 49 more network maps (will be coalesced to last one)
|
||||||
|
var lastNetmapBatch1 *network_map.UpdateMessage
|
||||||
|
for i := 1; i < 50; i++ {
|
||||||
|
lastNetmapBatch1 = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(lastNetmapBatch1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 1 control config
|
||||||
|
controlConfig := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(controlConfig)
|
||||||
|
|
||||||
|
// Send 50 more network maps (will be coalesced to last one)
|
||||||
|
var lastNetmapBatch2 *network_map.UpdateMessage
|
||||||
|
for i := 50; i < 100; i++ {
|
||||||
|
lastNetmapBatch2 = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(lastNetmapBatch2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get exactly 3 updates: netmap, controlConfig, netmap
|
||||||
|
if len(pendingUpdates) != 3 {
|
||||||
|
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// First should be the last netmap from batch 1
|
||||||
|
if pendingUpdates[0] != lastNetmapBatch1 {
|
||||||
|
t.Error("First pending update should be last netmap from batch 1")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
|
||||||
|
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
// Second should be the control config
|
||||||
|
if pendingUpdates[1] != controlConfig {
|
||||||
|
t.Error("Second pending update should be control config")
|
||||||
|
}
|
||||||
|
// Third should be the last netmap from batch 2
|
||||||
|
if pendingUpdates[2] != lastNetmapBatch2 {
|
||||||
|
t.Error("Third pending update should be last netmap from batch 2")
|
||||||
|
}
|
||||||
|
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
|
||||||
|
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -795,6 +795,19 @@ func IsEmbeddedIdp(i idp.Manager) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsLocalAuthDisabled checks if local (email/password) authentication is disabled.
|
||||||
|
// Returns true only when using embedded IDP with local auth disabled in config.
|
||||||
|
func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool {
|
||||||
|
if isNil(i) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
embeddedIdp, ok := i.(*idp.EmbeddedIdPManager)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return embeddedIdp.IsLocalAuthDisabled()
|
||||||
|
}
|
||||||
|
|
||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||||
@@ -1657,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
|||||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
|
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
@@ -1671,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
|||||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||||
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
|
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.Status.LastSeen.After(streamStartTime) {
|
||||||
|
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
||||||
|
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ type Manager interface {
|
|||||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||||
@@ -114,8 +114,8 @@ type Manager interface {
|
|||||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
|
|||||||
@@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||||
@@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
failed := waitTimeout(wg, time.Second)
|
failed := waitTimeout(wg, time.Second)
|
||||||
@@ -1961,6 +1961,82 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||||
|
manager, _, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||||
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
|
key, err := wgtypes.GenerateKey()
|
||||||
|
require.NoError(t, err, "unable to generate WireGuard key")
|
||||||
|
peerPubKey := key.PublicKey().String()
|
||||||
|
|
||||||
|
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||||
|
Key: peerPubKey,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||||
|
}, false)
|
||||||
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
|
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||||
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
|
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err, "unable to get peer")
|
||||||
|
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||||
|
|
||||||
|
streamStartTime := time.Now().UTC()
|
||||||
|
|
||||||
|
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||||
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
|
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||||
|
|
||||||
|
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
||||||
|
|
||||||
|
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected,
|
||||||
|
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
||||||
|
node2SyncTime := time.Now().UTC()
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
||||||
|
require.NoError(t, err, "node 2 should connect peer")
|
||||||
|
|
||||||
|
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||||
|
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
||||||
|
|
||||||
|
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
||||||
|
require.NoError(t, err, "stale connect should not return error")
|
||||||
|
|
||||||
|
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected, "peer should still be connected")
|
||||||
|
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
||||||
|
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||||
manager, _, err := createManager(t)
|
manager, _, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
@@ -1983,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
@@ -3176,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -140,15 +140,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if embedded IdP is enabled
|
// Check if embedded IdP is enabled for instance manager
|
||||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
|
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||||
users.AddEndpoints(accountManager, router)
|
users.AddEndpoints(accountManager, router)
|
||||||
users.AddInvitesEndpoints(accountManager, router)
|
users.AddInvitesEndpoints(accountManager, router)
|
||||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||||
|
|||||||
@@ -36,24 +36,22 @@ const (
|
|||||||
|
|
||||||
// handler is a handler that handles the server.Account HTTP endpoints
|
// handler is a handler that handles the server.Account HTTP endpoints
|
||||||
type handler struct {
|
type handler struct {
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
embeddedIdpEnabled bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
|
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||||
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
|
accountsHandler := newHandler(accountManager, settingsManager)
|
||||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
// newHandler creates a new handler HTTP handler
|
// newHandler creates a new handler HTTP handler
|
||||||
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
|
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
|
||||||
return &handler{
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
settingsManager: settingsManager,
|
settingsManager: settingsManager,
|
||||||
embeddedIdpEnabled: embeddedIdpEnabled,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +163,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
|
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,7 +290,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
|
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
@@ -321,7 +319,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
|
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
|
||||||
jwtAllowGroups := settings.JWTAllowGroups
|
jwtAllowGroups := settings.JWTAllowGroups
|
||||||
if jwtAllowGroups == nil {
|
if jwtAllowGroups == nil {
|
||||||
jwtAllowGroups = []string{}
|
jwtAllowGroups = []string{}
|
||||||
@@ -341,7 +339,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
|||||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||||
DnsDomain: &settings.DNSDomain,
|
DnsDomain: &settings.DNSDomain,
|
||||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||||
EmbeddedIdpEnabled: &embeddedIdpEnabled,
|
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
|
||||||
|
LocalAuthDisabled: &settings.LocalAuthDisabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.NetworkRange.IsValid() {
|
if settings.NetworkRange.IsValid() {
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
|||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
return &handler{
|
return &handler{
|
||||||
embeddedIdpEnabled: false,
|
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||||
return account.Settings, nil
|
return account.Settings, nil
|
||||||
@@ -124,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
AutoUpdateVersion: sr(""),
|
AutoUpdateVersion: sr(""),
|
||||||
EmbeddedIdpEnabled: br(false),
|
EmbeddedIdpEnabled: br(false),
|
||||||
|
LocalAuthDisabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: true,
|
expectedArray: true,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -148,6 +148,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
AutoUpdateVersion: sr(""),
|
AutoUpdateVersion: sr(""),
|
||||||
EmbeddedIdpEnabled: br(false),
|
EmbeddedIdpEnabled: br(false),
|
||||||
|
LocalAuthDisabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -172,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
AutoUpdateVersion: sr("latest"),
|
AutoUpdateVersion: sr("latest"),
|
||||||
EmbeddedIdpEnabled: br(false),
|
EmbeddedIdpEnabled: br(false),
|
||||||
|
LocalAuthDisabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -196,6 +198,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
AutoUpdateVersion: sr(""),
|
AutoUpdateVersion: sr(""),
|
||||||
EmbeddedIdpEnabled: br(false),
|
EmbeddedIdpEnabled: br(false),
|
||||||
|
LocalAuthDisabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -220,6 +223,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
AutoUpdateVersion: sr(""),
|
AutoUpdateVersion: sr(""),
|
||||||
EmbeddedIdpEnabled: br(false),
|
EmbeddedIdpEnabled: br(false),
|
||||||
|
LocalAuthDisabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -244,6 +248,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
AutoUpdateVersion: sr(""),
|
AutoUpdateVersion: sr(""),
|
||||||
EmbeddedIdpEnabled: br(false),
|
EmbeddedIdpEnabled: br(false),
|
||||||
|
LocalAuthDisabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
|
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired)
|
||||||
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
|
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
|
||||||
SetupRequired: setupRequired,
|
SetupRequired: setupRequired,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
@@ -26,11 +27,12 @@ import (
|
|||||||
// Handler is a handler that returns peers of the account
|
// Handler is a handler that returns peers of the account
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
|
permissionsManager permissions.Manager
|
||||||
networkMapController network_map.Controller
|
networkMapController network_map.Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
|
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
|
||||||
peersHandler := NewHandler(accountManager, networkMapController)
|
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
|
||||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||||
@@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new peers Handler
|
// NewHandler creates a new peers Handler
|
||||||
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
|
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
networkMapController: networkMapController,
|
networkMapController: networkMapController,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,13 +362,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -13,13 +13,15 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"go.uber.org/mock/gomock"
|
ugomock "go.uber.org/mock/gomock"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
@@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := ugomock.NewController(t)
|
||||||
|
|
||||||
networkMapController := network_map.NewMockController(ctrl)
|
networkMapController := network_map.NewMockController(ctrl)
|
||||||
networkMapController.EXPECT().
|
networkMapController.EXPECT().
|
||||||
@@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
|||||||
Return("domain").
|
Return("domain").
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
|
ctrl2 := gomock.NewController(t)
|
||||||
|
permissionsManager := permissions.NewMockManager(ctrl2)
|
||||||
|
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
return &Handler{
|
return &Handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
@@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
networkMapController: networkMapController,
|
networkMapController: networkMapController,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -205,6 +205,14 @@ func TestCreateInvite(t *testing.T) {
|
|||||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "local auth disabled",
|
||||||
|
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||||
|
expectedStatus: http.StatusPreconditionFailed,
|
||||||
|
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "invalid JSON",
|
name: "invalid JSON",
|
||||||
requestBody: `{invalid json}`,
|
requestBody: `{invalid json}`,
|
||||||
@@ -376,6 +384,15 @@ func TestAcceptInvite(t *testing.T) {
|
|||||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "local auth disabled",
|
||||||
|
token: testInviteToken,
|
||||||
|
requestBody: `{"password":"SecurePass123!"}`,
|
||||||
|
expectedStatus: http.StatusPreconditionFailed,
|
||||||
|
mockFunc: func(ctx context.Context, token, password string) error {
|
||||||
|
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "missing token",
|
name: "missing token",
|
||||||
token: "",
|
token: "",
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
proxyController := integrations.NewController(store)
|
proxyController := integrations.NewController(store)
|
||||||
userManager := users.NewManager(store)
|
userManager := users.NewManager(store)
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{})
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|||||||
@@ -43,6 +43,11 @@ type EmbeddedIdPConfig struct {
|
|||||||
Owner *OwnerConfig
|
Owner *OwnerConfig
|
||||||
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
|
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
|
||||||
SignKeyRefreshEnabled bool
|
SignKeyRefreshEnabled bool
|
||||||
|
// LocalAuthDisabled disables the local (email/password) authentication connector.
|
||||||
|
// When true, users cannot authenticate via email/password, only via external identity providers.
|
||||||
|
// Existing local users are preserved and will be able to login again if re-enabled.
|
||||||
|
// Cannot be enabled if no external identity provider connectors are configured.
|
||||||
|
LocalAuthDisabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||||
@@ -110,6 +115,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
|||||||
Issuer: "NetBird",
|
Issuer: "NetBird",
|
||||||
Theme: "light",
|
Theme: "light",
|
||||||
},
|
},
|
||||||
|
// Always enable password DB initially - we disable the local connector after startup if needed.
|
||||||
|
// This ensures Dex has at least one connector during initialization.
|
||||||
EnablePasswordDB: true,
|
EnablePasswordDB: true,
|
||||||
StaticClients: []storage.Client{
|
StaticClients: []storage.Client{
|
||||||
{
|
{
|
||||||
@@ -197,11 +204,32 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config)
|
||||||
|
|
||||||
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
|
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
|
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If local auth is disabled, validate that other connectors exist
|
||||||
|
if config.LocalAuthDisabled {
|
||||||
|
hasOthers, err := provider.HasNonLocalConnectors(ctx)
|
||||||
|
if err != nil {
|
||||||
|
_ = provider.Stop(ctx)
|
||||||
|
return nil, fmt.Errorf("failed to check connectors: %w", err)
|
||||||
|
}
|
||||||
|
if !hasOthers {
|
||||||
|
_ = provider.Stop(ctx)
|
||||||
|
return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||||
|
}
|
||||||
|
// Ensure local connector is removed (it might exist from a previous run)
|
||||||
|
if err := provider.DisableLocalAuth(ctx); err != nil {
|
||||||
|
_ = provider.Stop(ctx)
|
||||||
|
return nil, fmt.Errorf("failed to disable local auth: %w", err)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used")
|
||||||
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
|
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
|
||||||
|
|
||||||
return &EmbeddedIdPManager{
|
return &EmbeddedIdPManager{
|
||||||
@@ -286,6 +314,8 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
|
|||||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users))
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
indexedUsers := make(map[string][]*UserData)
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
|
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
|
||||||
@@ -295,11 +325,17 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID]))
|
||||||
|
|
||||||
return indexedUsers, nil
|
return indexedUsers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser creates a new user in the embedded IdP.
|
// CreateUser creates a new user in the embedded IdP.
|
||||||
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||||
|
if m.config.LocalAuthDisabled {
|
||||||
|
return nil, fmt.Errorf("local user creation is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
if m.appMetrics != nil {
|
if m.appMetrics != nil {
|
||||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
}
|
}
|
||||||
@@ -369,6 +405,10 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) (
|
|||||||
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
|
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
|
||||||
// This is useful for instance setup where the user provides their own password.
|
// This is useful for instance setup where the user provides their own password.
|
||||||
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
|
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
|
||||||
|
if m.config.LocalAuthDisabled {
|
||||||
|
return nil, fmt.Errorf("local user creation is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
if m.appMetrics != nil {
|
if m.appMetrics != nil {
|
||||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
}
|
}
|
||||||
@@ -558,3 +598,13 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
|||||||
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
|
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
|
||||||
return defaultUserIDClaim
|
return defaultUserIDClaim
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration.
|
||||||
|
func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
|
||||||
|
return m.config.LocalAuthDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasNonLocalConnectors checks if there are any identity provider connectors other than local.
|
||||||
|
func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||||
|
return m.provider.HasNonLocalConnectors(ctx)
|
||||||
|
}
|
||||||
|
|||||||
@@ -370,3 +370,234 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
config := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
LocalAuthDisabled: true,
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: filepath.Join(tmpDir, "dex.db"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = NewEmbeddedIdPManager(ctx, config, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no other identity providers configured")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("local auth enabled by default", func(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
config := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: filepath.Join(tmpDir, "dex.db"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = manager.Stop(ctx) }()
|
||||||
|
|
||||||
|
// Verify local auth is enabled by default
|
||||||
|
assert.False(t, manager.IsLocalAuthDisabled())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("start with local auth disabled when connector exists", func(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||||
|
|
||||||
|
// First, create a manager with local auth enabled and add a connector
|
||||||
|
config1 := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: dbFile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a user
|
||||||
|
userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
userID := userData.ID
|
||||||
|
|
||||||
|
// Add an external connector (Google doesn't require OIDC discovery)
|
||||||
|
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||||
|
ID: "google-test",
|
||||||
|
Name: "Google Test",
|
||||||
|
Type: "google",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
ClientSecret: "test-client-secret",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Stop the first manager
|
||||||
|
err = manager1.Stop(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now create a new manager with local auth disabled
|
||||||
|
config2 := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
LocalAuthDisabled: true,
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: dbFile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = manager2.Stop(ctx) }()
|
||||||
|
|
||||||
|
// Verify local auth is disabled via config
|
||||||
|
assert.True(t, manager2.IsLocalAuthDisabled())
|
||||||
|
|
||||||
|
// Verify the user still exists in storage (just can't login via local)
|
||||||
|
lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "preserved@example.com", lookedUp.Email)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||||
|
|
||||||
|
// First, create a manager and add an external connector
|
||||||
|
config1 := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: dbFile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||||
|
ID: "google-test",
|
||||||
|
Name: "Google Test",
|
||||||
|
Type: "google",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
ClientSecret: "test-client-secret",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager1.Stop(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create manager with local auth disabled
|
||||||
|
config2 := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
LocalAuthDisabled: true,
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: dbFile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = manager2.Stop(ctx) }()
|
||||||
|
|
||||||
|
// Try to create a user - should fail
|
||||||
|
_, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "local user creation is disabled")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||||
|
|
||||||
|
// First, create a manager and add an external connector
|
||||||
|
config1 := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: dbFile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||||
|
ID: "google-test",
|
||||||
|
Name: "Google Test",
|
||||||
|
Type: "google",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
ClientSecret: "test-client-secret",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager1.Stop(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create manager with local auth disabled
|
||||||
|
config2 := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
LocalAuthDisabled: true,
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: dbFile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = manager2.Stop(ctx) }()
|
||||||
|
|
||||||
|
// Try to create a user with password - should fail
|
||||||
|
_, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "local user creation is disabled")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -104,13 +104,22 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
|
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
|
||||||
|
// Check if there are any accounts in the NetBird store
|
||||||
|
numAccounts, err := m.store.GetAccountsCounter(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hasAccounts := numAccounts > 0
|
||||||
|
|
||||||
|
// Check if there are any users in the embedded IdP (Dex)
|
||||||
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
|
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
hasLocalUsers := len(users) > 0
|
||||||
|
|
||||||
m.setupMu.Lock()
|
m.setupMu.Lock()
|
||||||
m.setupRequired = len(users) == 0
|
m.setupRequired = !(hasAccounts || hasLocalUsers)
|
||||||
m.setupMu.Unlock()
|
m.setupMu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
|||||||
|
|
||||||
initialPeers := 10
|
initialPeers := 10
|
||||||
additionalPeers := 10
|
additionalPeers := 10
|
||||||
|
expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself
|
||||||
|
|
||||||
var peers []wgtypes.Key
|
var peers []wgtypes.Key
|
||||||
for i := 0; i < initialPeers; i++ {
|
for i := 0; i < initialPeers; i++ {
|
||||||
@@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
|||||||
peers = append(peers, key)
|
peers = append(peers, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Track the maximum peer count each peer has seen
|
||||||
|
type peerState struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
maxPeerCount int
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
peerStates := make(map[string]*peerState)
|
||||||
|
for _, pk := range peers {
|
||||||
|
peerStates[pk.PublicKey().String()] = &peerState{}
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(initialPeers + initialPeers*additionalPeers)
|
wg.Add(initialPeers) // One completion per initial peer
|
||||||
|
|
||||||
var syncClients []mgmtProto.ManagementService_SyncClient
|
var syncClients []mgmtProto.ManagementService_SyncClient
|
||||||
for _, pk := range peers {
|
for _, pk := range peers {
|
||||||
@@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
|||||||
syncClients = append(syncClients, s)
|
syncClients = append(syncClients, s)
|
||||||
|
|
||||||
go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) {
|
go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) {
|
||||||
|
pubKey := pk.PublicKey().String()
|
||||||
|
state := peerStates[pubKey]
|
||||||
|
|
||||||
for {
|
for {
|
||||||
encMsg := &mgmtProto.EncryptedMessage{}
|
encMsg := &mgmtProto.EncryptedMessage{}
|
||||||
err := syncStream.RecvMsg(encMsg)
|
err := syncStream.RecvMsg(encMsg)
|
||||||
@@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk)
|
decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk)
|
||||||
if decErr != nil {
|
if decErr != nil {
|
||||||
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr)
|
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp := &mgmtProto.SyncResponse{}
|
resp := &mgmtProto.SyncResponse{}
|
||||||
umErr := pb.Unmarshal(decryptedBytes, resp)
|
umErr := pb.Unmarshal(decryptedBytes, resp)
|
||||||
if umErr != nil {
|
if umErr != nil {
|
||||||
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr)
|
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// We only count if there's a new peer update
|
|
||||||
if len(resp.GetRemotePeers()) > 0 {
|
// Track the maximum peer count seen (due to debouncing, updates are coalesced)
|
||||||
|
peerCount := len(resp.GetRemotePeers())
|
||||||
|
state.mu.Lock()
|
||||||
|
if peerCount > state.maxPeerCount {
|
||||||
|
state.maxPeerCount = peerCount
|
||||||
|
}
|
||||||
|
// Signal completion when this peer has seen all expected peers
|
||||||
|
if !state.done && state.maxPeerCount >= expectedPeerCount {
|
||||||
|
state.done = true
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}
|
}
|
||||||
|
state.mu.Unlock()
|
||||||
}
|
}
|
||||||
}(pk, s)
|
}(pk, s)
|
||||||
}
|
}
|
||||||
@@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
|||||||
time.Sleep(time.Duration(n) * time.Millisecond)
|
time.Sleep(time.Duration(n) * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
// Wait for debouncer to flush final updates (debounce interval is 1000ms)
|
||||||
|
time.Sleep(1500 * time.Millisecond)
|
||||||
|
|
||||||
|
// Wait with timeout
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Success - all peers received expected peer count
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
// Timeout - report which peers didn't receive all updates
|
||||||
|
t.Error("Timeout waiting for all peers to receive updates")
|
||||||
|
for pubKey, state := range peerStates {
|
||||||
|
state.mu.Lock()
|
||||||
|
if state.maxPeerCount < expectedPeerCount {
|
||||||
|
t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount)
|
||||||
|
}
|
||||||
|
state.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, sc := range syncClients {
|
for _, sc := range syncClients {
|
||||||
err := sc.CloseSend()
|
err := sc.CloseSend()
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ type MockAccountManager struct {
|
|||||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
|
||||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||||
@@ -214,16 +214,15 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
|
|||||||
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
if am.SyncAndMarkPeerFunc != nil {
|
if am.SyncAndMarkPeerFunc != nil {
|
||||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime)
|
||||||
}
|
}
|
||||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
|
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||||
// TODO implement me
|
return nil
|
||||||
panic("implement me")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||||
@@ -323,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
|
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||||
if am.MarkPeerConnectedFunc != nil {
|
if am.MarkPeerConnectedFunc != nil {
|
||||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
|
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
// syncTime is used as the LastSeen timestamp and for stale request detection
|
||||||
|
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||||
var peer *nbpeer.Peer
|
var peer *nbpeer.Peer
|
||||||
var settings *types.Settings
|
var settings *types.Settings
|
||||||
var expired bool
|
var expired bool
|
||||||
var err error
|
var err error
|
||||||
|
var skipped bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||||
@@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
|
if connected && !syncTime.After(peer.Status.LastSeen) {
|
||||||
|
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
|
||||||
|
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
|
||||||
|
skipped = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
if skipped {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
|
||||||
oldStatus := peer.Status.Copy()
|
oldStatus := peer.Status.Copy()
|
||||||
newStatus := oldStatus
|
newStatus := oldStatus
|
||||||
newStatus.LastSeen = time.Now().UTC()
|
newStatus.LastSeen = syncTime
|
||||||
newStatus.Connected = connected
|
newStatus.Connected = connected
|
||||||
// whenever peer got connected that means that it logged in successfully
|
// whenever peer got connected that means that it logged in successfully
|
||||||
if newStatus.Connected {
|
if newStatus.Connected {
|
||||||
|
|||||||
@@ -24,19 +24,28 @@ type Manager interface {
|
|||||||
UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error)
|
UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IdpConfig holds IdP-related configuration that is set at runtime
|
||||||
|
// and not stored in the database.
|
||||||
|
type IdpConfig struct {
|
||||||
|
EmbeddedIdpEnabled bool
|
||||||
|
LocalAuthDisabled bool
|
||||||
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
extraSettingsManager extra_settings.Manager
|
extraSettingsManager extra_settings.Manager
|
||||||
userManager users.Manager
|
userManager users.Manager
|
||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
|
idpConfig IdpConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager {
|
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager {
|
||||||
return &managerImpl{
|
return &managerImpl{
|
||||||
store: store,
|
store: store,
|
||||||
extraSettingsManager: extraSettingsManager,
|
extraSettingsManager: extraSettingsManager,
|
||||||
userManager: userManager,
|
userManager: userManager,
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
|
idpConfig: idpConfig,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,6 +83,10 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
|
|||||||
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
|
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fill in IdP-related runtime settings
|
||||||
|
settings.EmbeddedIdpEnabled = m.idpConfig.EmbeddedIdpEnabled
|
||||||
|
settings.LocalAuthDisabled = m.idpConfig.LocalAuthDisabled
|
||||||
|
|
||||||
return settings, nil
|
return settings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,14 @@ type Settings struct {
|
|||||||
|
|
||||||
// AutoUpdateVersion client auto-update version
|
// AutoUpdateVersion client auto-update version
|
||||||
AutoUpdateVersion string `gorm:"default:'disabled'"`
|
AutoUpdateVersion string `gorm:"default:'disabled'"`
|
||||||
|
|
||||||
|
// EmbeddedIdpEnabled indicates if the embedded identity provider is enabled.
|
||||||
|
// This is a runtime-only field, not stored in the database.
|
||||||
|
EmbeddedIdpEnabled bool `gorm:"-"`
|
||||||
|
|
||||||
|
// LocalAuthDisabled indicates if local (email/password) authentication is disabled.
|
||||||
|
// This is a runtime-only field, not stored in the database.
|
||||||
|
LocalAuthDisabled bool `gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy copies the Settings struct
|
// Copy copies the Settings struct
|
||||||
@@ -76,6 +84,8 @@ func (s *Settings) Copy() *Settings {
|
|||||||
DNSDomain: s.DNSDomain,
|
DNSDomain: s.DNSDomain,
|
||||||
NetworkRange: s.NetworkRange,
|
NetworkRange: s.NetworkRange,
|
||||||
AutoUpdateVersion: s.AutoUpdateVersion,
|
AutoUpdateVersion: s.AutoUpdateVersion,
|
||||||
|
EmbeddedIdpEnabled: s.EmbeddedIdpEnabled,
|
||||||
|
LocalAuthDisabled: s.LocalAuthDisabled,
|
||||||
}
|
}
|
||||||
if s.Extra != nil {
|
if s.Extra != nil {
|
||||||
settings.Extra = s.Extra.Copy()
|
settings.Extra = s.Extra.Copy()
|
||||||
|
|||||||
@@ -191,6 +191,10 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
|||||||
// Unlike createNewIdpUser, this method fetches user data directly from the database
|
// Unlike createNewIdpUser, this method fetches user data directly from the database
|
||||||
// since the embedded IdP usage ensures the username and email are stored locally in the User table.
|
// since the embedded IdP usage ensures the username and email are stored locally in the User table.
|
||||||
func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
|
func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
|
||||||
|
if IsLocalAuthDisabled(ctx, am.idpManager) {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||||
|
}
|
||||||
|
|
||||||
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID)
|
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get inviter user: %w", err)
|
return nil, fmt.Errorf("failed to get inviter user: %w", err)
|
||||||
@@ -1462,6 +1466,10 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID
|
|||||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if IsLocalAuthDisabled(ctx, am.idpManager) {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||||
|
}
|
||||||
|
|
||||||
if err := validateUserInvite(invite); err != nil {
|
if err := validateUserInvite(invite); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1621,6 +1629,10 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa
|
|||||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if IsLocalAuthDisabled(ctx, am.idpManager) {
|
||||||
|
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||||
|
}
|
||||||
|
|
||||||
if password == "" {
|
if password == "" {
|
||||||
return status.Errorf(status.InvalidArgument, "password is required")
|
return status.Errorf(status.InvalidArgument, "password is required")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -296,6 +296,11 @@ components:
|
|||||||
type: boolean
|
type: boolean
|
||||||
readOnly: true
|
readOnly: true
|
||||||
example: false
|
example: false
|
||||||
|
local_auth_disabled:
|
||||||
|
description: Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field.
|
||||||
|
type: boolean
|
||||||
|
readOnly: true
|
||||||
|
example: false
|
||||||
required:
|
required:
|
||||||
- peer_login_expiration_enabled
|
- peer_login_expiration_enabled
|
||||||
- peer_login_expiration
|
- peer_login_expiration
|
||||||
|
|||||||
@@ -446,6 +446,9 @@ type AccountSettings struct {
|
|||||||
// LazyConnectionEnabled Enables or disables experimental lazy connection
|
// LazyConnectionEnabled Enables or disables experimental lazy connection
|
||||||
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
||||||
|
|
||||||
|
// LocalAuthDisabled Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field.
|
||||||
|
LocalAuthDisabled *bool `json:"local_auth_disabled,omitempty"`
|
||||||
|
|
||||||
// NetworkRange Allows to define a custom network range for the account in CIDR format
|
// NetworkRange Allows to define a custom network range for the account in CIDR format
|
||||||
NetworkRange *string `json:"network_range,omitempty"`
|
NetworkRange *string `json:"network_range,omitempty"`
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user