mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-27 12:46:39 +00:00
Compare commits
21 Commits
feature/mi
...
feature/di
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
279e96e6b1 | ||
|
|
245481f33b | ||
|
|
b352ab84c0 | ||
|
|
3ce5d6a4f8 | ||
|
|
4c2eb2af73 | ||
|
|
daf1449174 | ||
|
|
1ff7abe909 | ||
|
|
067c77e49e | ||
|
|
291e640b28 | ||
|
|
efb954b7d6 | ||
|
|
cac9326d3d | ||
|
|
520d9c66cf | ||
|
|
ff10498a8b | ||
|
|
00b747ad5d | ||
|
|
d9118eb239 | ||
|
|
94de656fae | ||
|
|
37abab8b69 | ||
|
|
b12c084a50 | ||
|
|
394ad19507 | ||
|
|
614e7d5b90 | ||
|
|
f7967f9ae3 |
@@ -38,6 +38,11 @@
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.2
|
||||
FROM alpine:3.23.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -314,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
|
||||
statusOutputString = overview.FullDetailSummary()
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
statusOutputString = outputInformationHolder.FullDetailSummary()
|
||||
case jsonFlag:
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.JSON()
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.YAML()
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -38,6 +39,7 @@ type Client struct {
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
recorder *peer.Status
|
||||
}
|
||||
|
||||
// Options configures a new Client.
|
||||
@@ -161,11 +163,17 @@ func New(opts Options) (*Client, error) {
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cancel != nil {
|
||||
if c.connect != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
|
||||
defer func() {
|
||||
if c.connect == nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
@@ -173,7 +181,9 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
c.recorder = recorder
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
@@ -197,6 +207,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
c.cancel = cancel
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -211,17 +222,23 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
c.cancel = nil
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
connect := c.connect
|
||||
go func() {
|
||||
done <- c.connect.Stop()
|
||||
done <- connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancel = nil
|
||||
c.connect = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.cancel = nil
|
||||
c.connect = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
@@ -315,6 +332,62 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
recorder := c.recorder
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return peer.FullStatus{}, errors.New("client not started")
|
||||
}
|
||||
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
_ = engine.RunHealthProbes(false)
|
||||
}
|
||||
}
|
||||
|
||||
return recorder.GetFullStatus(), nil
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncResp, err := engine.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get sync response: %w", err)
|
||||
}
|
||||
|
||||
return syncResp, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the client and its components.
|
||||
func (c *Client) SetLogLevel(levelStr string) error {
|
||||
level, err := logrus.ParseLevel(levelStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
|
||||
logrus.SetLevel(level)
|
||||
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect != nil {
|
||||
connect.SetLogLevel(level)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
|
||||
@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
|
||||
return syncResponse, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager if the engine is running.
|
||||
func (c *ConnectClient) SetLogLevel(level log.Level) {
|
||||
engine := c.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager != nil {
|
||||
fwManager.SetLogLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
|
||||
@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.SkipPTRProcess {
|
||||
if zone.NonAuthoritative {
|
||||
continue
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
|
||||
@@ -3,17 +3,21 @@ package dns
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
)
|
||||
|
||||
const (
|
||||
PriorityMgmtCache = 150
|
||||
PriorityLocal = 100
|
||||
PriorityDNSRoute = 75
|
||||
PriorityDNSRoute = 100
|
||||
PriorityLocal = 75
|
||||
PriorityUpstream = 50
|
||||
PriorityDefault = 1
|
||||
PriorityFallback = -100
|
||||
@@ -43,7 +47,23 @@ type HandlerChain struct {
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
origPattern string
|
||||
requestID string
|
||||
shouldContinue bool
|
||||
response *dns.Msg
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
// RequestID returns the request ID for tracing
|
||||
func (w *ResponseWriterChain) RequestID() string {
|
||||
return w.requestID
|
||||
}
|
||||
|
||||
// SetMeta sets a metadata key-value pair for logging
|
||||
func (w *ResponseWriterChain) SetMeta(key, value string) {
|
||||
if w.meta == nil {
|
||||
w.meta = make(map[string]string)
|
||||
}
|
||||
w.meta[key] = value
|
||||
}
|
||||
|
||||
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
w.shouldContinue = true
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
|
||||
pos := c.findHandlerPosition(entry)
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
|
||||
c.logHandlers()
|
||||
}
|
||||
|
||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||
@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
c.logHandlers()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// logHandlers logs the current handler chain state. Caller must hold the lock.
|
||||
func (c *HandlerChain) logHandlers() {
|
||||
if !log.IsLevelEnabled(log.TraceLevel) {
|
||||
return
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
|
||||
for _, h := range c.handlers {
|
||||
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
|
||||
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
|
||||
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
|
||||
" priority=" + strconv.Itoa(h.Priority) + "\n")
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
|
||||
c.mu.RLock()
|
||||
handlers := slices.Clone(c.handlers)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
||||
for _, h := range handlers {
|
||||
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
matched := c.isHandlerMatch(qname, entry)
|
||||
|
||||
if matched {
|
||||
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
// Only log continue for non-management cache handlers to reduce noise
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
|
||||
handlerName := entry.OrigPattern
|
||||
if s, ok := entry.Handler.(interface{ String() string }); ok {
|
||||
handlerName = s.String()
|
||||
}
|
||||
|
||||
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
|
||||
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
requestID: requestID,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
logger.Tracef("handler requested continue for domain=%s", qname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
c.logResponse(logger, chainWriter, qname, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
log.Tracef("no handler found for domain=%s", qname)
|
||||
logger.Tracef("no handler found for domain=%s type=%s class=%s",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeRefused)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
|
||||
if cw.response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var meta string
|
||||
for k, v := range cw.meta {
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
|
||||
@@ -1,30 +1,52 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const externalResolutionTimeout = 4 * time.Second
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
domains map[domain.Domain]struct{}
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewResolver() *Resolver {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
domains: make(map[domain.Domain]struct{}),
|
||||
zones: make(map[domain.Domain]bool),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +59,18 @@ func (d *Resolver) String() string {
|
||||
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
||||
}
|
||||
|
||||
func (d *Resolver) Stop() {}
|
||||
func (d *Resolver) Stop() {
|
||||
if d.cancel != nil {
|
||||
d.cancel()
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *Resolver) ID() types.HandlerID {
|
||||
@@ -48,35 +81,85 @@ func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
log.Debugf("received local resolver request with no question")
|
||||
logger.Debug("received local resolver request with no question")
|
||||
return
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(question)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
// Check if we have any records for this domain name with different types
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
|
||||
}
|
||||
result := d.lookupRecords(logger, question)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
|
||||
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
|
||||
d.continueToNext(logger, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(replyMessage); err != nil {
|
||||
log.Warnf("failed to write the local resolver response: %v", err)
|
||||
logger.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// determineRcode returns the appropriate DNS response code.
|
||||
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
|
||||
// even if CNAME records are included in the answer.
|
||||
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
|
||||
// Use the rcode from lookup - this properly handles CNAME chains where
|
||||
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
|
||||
if result.rcode != 0 {
|
||||
return result.rcode
|
||||
}
|
||||
|
||||
// No records found, but domain exists with different record types (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
|
||||
// findZone finds the matching zone for a query name using reverse suffix lookup.
|
||||
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
|
||||
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
|
||||
qname = strings.ToLower(dns.Fqdn(qname))
|
||||
for {
|
||||
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
|
||||
return nonAuth, true
|
||||
}
|
||||
// Move to parent domain
|
||||
idx := strings.Index(qname, ".")
|
||||
if idx == -1 || idx == len(qname)-1 {
|
||||
return false, false
|
||||
}
|
||||
qname = qname[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// shouldFallthrough checks if the query should fallthrough to the next handler.
|
||||
// Returns true if the queried name belongs to a non-authoritative zone.
|
||||
func (d *Resolver) shouldFallthrough(qname string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
nonAuth, found := d.findZone(qname)
|
||||
return found && nonAuth
|
||||
}
|
||||
|
||||
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Warnf("failed to write continue signal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,8 +172,27 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||
return exists
|
||||
}
|
||||
|
||||
// isInManagedZone checks if the given name falls within any of our managed zones.
|
||||
// This is used to avoid unnecessary external resolution for CNAME targets that
|
||||
// are within zones we manage - if we don't have a record for it, it doesn't exist.
|
||||
// Caller must NOT hold the lock.
|
||||
func (d *Resolver) isInManagedZone(name string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
_, found := d.findZone(name)
|
||||
return found
|
||||
}
|
||||
|
||||
// lookupResult contains the result of a DNS lookup operation.
|
||||
type lookupResult struct {
|
||||
records []dns.RR
|
||||
rcode int
|
||||
hasExternalData bool
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
||||
d.mu.RLock()
|
||||
records, found := d.records[question]
|
||||
|
||||
@@ -98,10 +200,14 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
d.mu.RUnlock()
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
question.Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(question)
|
||||
cnameQuestion := dns.Question{
|
||||
Name: question.Name,
|
||||
Qtype: dns.TypeCNAME,
|
||||
Qclass: question.Qclass,
|
||||
}
|
||||
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
|
||||
}
|
||||
return nil
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
}
|
||||
|
||||
recordsCopy := slices.Clone(records)
|
||||
@@ -119,20 +225,178 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
return recordsCopy
|
||||
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
||||
// the final resolved record of the requested type. This is required for musl libc
|
||||
// compatibility, which expects the full answer chain rather than just the CNAME.
|
||||
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
|
||||
const maxDepth = 8
|
||||
var chain []dns.RR
|
||||
|
||||
for range maxDepth {
|
||||
cnameRecords := d.getRecords(cnameQuestion)
|
||||
if len(cnameRecords) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
chain = append(chain, cnameRecords...)
|
||||
|
||||
cname, ok := cnameRecords[0].(*dns.CNAME)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
targetName := strings.ToLower(cname.Target)
|
||||
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
|
||||
|
||||
// keep following chain
|
||||
if targetResult.rcode == -1 {
|
||||
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
|
||||
continue
|
||||
}
|
||||
|
||||
return d.buildChainResult(chain, targetResult)
|
||||
}
|
||||
|
||||
if len(chain) > 0 {
|
||||
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// buildChainResult combines CNAME chain records with the target resolution result.
|
||||
// Per RFC 6604, the final rcode is propagated through the chain.
|
||||
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
|
||||
records := chain
|
||||
if len(target.records) > 0 {
|
||||
records = append(records, target.records...)
|
||||
}
|
||||
|
||||
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
|
||||
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: dns.RcodeServerFailure,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: target.rcode,
|
||||
hasExternalData: target.hasExternalData,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCNAMETarget attempts to resolve a CNAME target name.
|
||||
// Returns rcode=-1 to signal "keep following the chain".
|
||||
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
|
||||
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
|
||||
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// another CNAME, keep following
|
||||
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
|
||||
return lookupResult{rcode: -1}
|
||||
}
|
||||
|
||||
// domain exists locally but not this record type (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(targetName)) {
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// in our zone but doesn't exist (NXDOMAIN)
|
||||
if d.isInManagedZone(targetName) {
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
}
|
||||
|
||||
return d.resolveExternal(logger, targetName, targetType)
|
||||
}
|
||||
|
||||
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.records[q]
|
||||
}
|
||||
|
||||
func (d *Resolver) hasRecord(q dns.Question) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
_, ok := d.records[q]
|
||||
return ok
|
||||
}
|
||||
|
||||
// resolveExternal resolves a domain name using the system resolver.
|
||||
// This is used to resolve CNAME targets that point outside our local zone,
|
||||
// which is required for musl libc compatibility (musl expects complete answers).
|
||||
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return lookupResult{rcode: dns.RcodeNotImplemented}
|
||||
}
|
||||
|
||||
resolver := d.resolver
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
|
||||
if result.Err != nil {
|
||||
d.logDNSError(logger, name, qtype, result.Err)
|
||||
return lookupResult{rcode: result.Rcode, hasExternalData: true}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: resutil.IPsToRRs(name, result.IPs, 60),
|
||||
rcode: dns.RcodeSuccess,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
// logDNSError logs DNS resolution errors for debugging.
|
||||
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
|
||||
qtypeName := dns.TypeToString[qtype]
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.IsNotFound {
|
||||
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
|
||||
} else {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
|
||||
for _, rec := range update {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
continue
|
||||
for _, zone := range customZones {
|
||||
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
|
||||
d.zones[zoneDomain] = zone.NonAuthoritative
|
||||
|
||||
for _, rec := range zone.Records {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,6 +18,18 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
// mockResolver implements resolver for testing
|
||||
type mockResolver struct {
|
||||
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if m.lookupFunc != nil {
|
||||
return m.lookupFunc(ctx, network, host)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
@@ -106,11 +124,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
|
||||
resolver := NewResolver()
|
||||
|
||||
update1 := []nbdns.SimpleRecord{record1}
|
||||
update2 := []nbdns.SimpleRecord{record2}
|
||||
zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
|
||||
zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
|
||||
|
||||
// Apply first update
|
||||
resolver.Update(update1)
|
||||
resolver.Update(zone1)
|
||||
|
||||
// Verify first update
|
||||
resolver.mu.RLock()
|
||||
@@ -122,7 +140,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||
|
||||
// Apply second update
|
||||
resolver.Update(update2)
|
||||
resolver.Update(zone2)
|
||||
|
||||
// Verify second update
|
||||
resolver.mu.RLock()
|
||||
@@ -151,10 +169,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
|
||||
}
|
||||
|
||||
update := []nbdns.SimpleRecord{record1, record2}
|
||||
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
|
||||
|
||||
// Apply update with both records
|
||||
resolver.Update(update)
|
||||
resolver.Update(zones)
|
||||
|
||||
// Create question that matches both records
|
||||
question := dns.Question{
|
||||
@@ -195,10 +213,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
|
||||
}
|
||||
|
||||
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
|
||||
|
||||
// Apply update with all three records
|
||||
resolver.Update(update)
|
||||
resolver.Update(zones)
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||
|
||||
@@ -264,7 +282,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update resolver with the records
|
||||
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -379,7 +397,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update resolver with both records
|
||||
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -476,6 +494,20 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
// with 0 records instead of NXDOMAIN
|
||||
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
// Mock external resolver for CNAME target resolution
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "target.example.com." {
|
||||
if network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
if network == "ip6" {
|
||||
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||
}
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
}
|
||||
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "example.netbird.cloud.",
|
||||
@@ -493,7 +525,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
RData: "target.example.com.",
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -582,3 +614,808 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
|
||||
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
|
||||
t.Run("simple internal CNAME chain", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
|
||||
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 2)
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "target.example.com.", cname.Target)
|
||||
|
||||
a, ok := resp.Answer[1].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "192.168.1.1", a.A.String())
|
||||
})
|
||||
|
||||
t.Run("multi-hop CNAME chain", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
|
||||
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
|
||||
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 3)
|
||||
})
|
||||
|
||||
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
|
||||
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
|
||||
t.Run("chain at max depth resolves", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
var records []nbdns.SimpleRecord
|
||||
// Create chain of 7 CNAMEs (under max of 8)
|
||||
for i := 1; i <= 7; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("hop%d.test.", i),
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("hop%d.test.", i+1),
|
||||
})
|
||||
}
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||
})
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 8)
|
||||
})
|
||||
|
||||
t.Run("chain exceeding max depth stops", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
var records []nbdns.SimpleRecord
|
||||
// Create chain of 10 CNAMEs (exceeds max of 8)
|
||||
for i := 1; i <= 10; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("deep%d.test.", i),
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("deep%d.test.", i+1),
|
||||
})
|
||||
}
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||
})
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
// Should NOT have the final A record (chain too deep)
|
||||
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||
})
|
||||
|
||||
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
|
||||
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
|
||||
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
|
||||
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "external.example.com.", cname.Target)
|
||||
|
||||
a, ok := resp.Answer[1].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "93.184.216.34", a.A.String())
|
||||
})
|
||||
|
||||
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip6" {
|
||||
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "external.example.com.", cname.Target)
|
||||
|
||||
aaaa, ok := resp.Answer[1].(*dns.AAAA)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
|
||||
})
|
||||
|
||||
t.Run("concurrent external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]*dns.Msg, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
results[idx] = resp
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, resp := range results {
|
||||
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
|
||||
func TestLocalResolver_ZoneManagement(t *testing.T) {
|
||||
t.Run("Update sets zones correctly", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{
|
||||
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
}},
|
||||
{Domain: "test.local."},
|
||||
})
|
||||
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("other.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("sub.test.local."))
|
||||
assert.False(t, resolver.isInManagedZone("external.com."))
|
||||
})
|
||||
|
||||
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
|
||||
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
|
||||
})
|
||||
|
||||
t.Run("Update clears zones", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
|
||||
resolver.Update(nil)
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
|
||||
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
|
||||
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
|
||||
})
|
||||
|
||||
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.other.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
|
||||
})
|
||||
|
||||
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
|
||||
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "Answer should be CNAME record")
|
||||
})
|
||||
|
||||
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." {
|
||||
if network == "ip6" {
|
||||
// No AAAA records
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
}
|
||||
if network == "ip4" {
|
||||
// But A records exist - domain exists
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should have only CNAME")
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "Answer should be CNAME record")
|
||||
})
|
||||
|
||||
// Table-driven test for all external resolution outcomes
|
||||
externalCases := []struct {
|
||||
name string
|
||||
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
|
||||
expectedRcode int
|
||||
expectedAnswer int
|
||||
}{
|
||||
{
|
||||
name: "external NXDOMAIN (both A and AAAA not found)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeNameError,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (temporary error)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsTemporary: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (timeout)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsTimeout: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (generic error)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external success with IPs",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectedAnswer: 2, // CNAME + A
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range externalCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
|
||||
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
|
||||
if tc.expectedAnswer > 0 {
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "first answer should be CNAME")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
|
||||
// trigger fallthrough (Zero bit set) when no records match
|
||||
func TestLocalResolver_Fallthrough(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
record := nbdns.SimpleRecord{
|
||||
Name: "existing.custom.zone.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.0.0.1",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
zones []nbdns.CustomZone
|
||||
queryName string
|
||||
expectFallthrough bool
|
||||
expectRecord bool
|
||||
}{
|
||||
{
|
||||
name: "Authoritative zone returns NXDOMAIN without fallthrough",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
}},
|
||||
queryName: "nonexistent.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Non-authoritative zone triggers fallthrough",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
NonAuthoritative: true,
|
||||
}},
|
||||
queryName: "nonexistent.custom.zone.",
|
||||
expectFallthrough: true,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Record found in non-authoritative zone returns normally",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
NonAuthoritative: true,
|
||||
}},
|
||||
queryName: "existing.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: true,
|
||||
},
|
||||
{
|
||||
name: "Record found in authoritative zone returns normally",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
}},
|
||||
queryName: "existing.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver.Update(tc.zones)
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
require.NotNil(t, responseMSG, "Should have received a response")
|
||||
|
||||
if tc.expectFallthrough {
|
||||
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
|
||||
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
|
||||
} else {
|
||||
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
|
||||
}
|
||||
|
||||
if tc.expectRecord {
|
||||
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
|
||||
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
t.Run("direct record lookup is authoritative", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.True(t, resp.Authoritative)
|
||||
})
|
||||
|
||||
t.Run("external resolution is not authoritative", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2)
|
||||
assert.False(t, resp.Authoritative)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
resolver.Stop()
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Len(t, resp.Answer, 0)
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
resolver.Stop()
|
||||
resolver.Stop()
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
lookupCtxCanceled := make(chan struct{})
|
||||
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
close(lookupStarted)
|
||||
<-ctx.Done()
|
||||
close(lookupCtxCanceled)
|
||||
return nil, ctx.Err()
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-lookupStarted
|
||||
resolver.Stop()
|
||||
|
||||
select {
|
||||
case <-lookupCtxCanceled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("external lookup context was not canceled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ServeDNS did not return after Stop")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
|
||||
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "EXAMPLE.COM.",
|
||||
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
require.NotNil(t, responseMSG)
|
||||
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
|
||||
func BenchmarkFindZone_BestCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Single zone that matches immediately
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough("example.com.")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
|
||||
func BenchmarkFindZone_WorstCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// 100 zones that won't match
|
||||
var zones []nbdns.CustomZone
|
||||
for i := 0; i < 100; i++ {
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("zone%d.internal.", i),
|
||||
NonAuthoritative: true,
|
||||
})
|
||||
}
|
||||
resolver.Update(zones)
|
||||
|
||||
// Query with many labels that won't match any zone
|
||||
qname := "a.b.c.d.e.f.g.h.external.com."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
|
||||
func BenchmarkFindZone_TypicalCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
|
||||
resolver.Update([]nbdns.CustomZone{
|
||||
{Domain: "netbird.cloud.", NonAuthoritative: false},
|
||||
{Domain: "custom.local.", NonAuthoritative: true},
|
||||
})
|
||||
|
||||
// Query for subdomain of user zone
|
||||
qname := "myhost.custom.local."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
|
||||
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
var zones []nbdns.CustomZone
|
||||
for i := 0; i < 100; i++ {
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("zone%d.internal.", i),
|
||||
})
|
||||
}
|
||||
resolver.Update(zones)
|
||||
|
||||
// Query that matches zone50
|
||||
qname := "host.zone50.internal."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.isInManagedZone(qname)
|
||||
}
|
||||
}
|
||||
|
||||
197
client/internal/dns/resutil/resolve.go
Normal file
197
client/internal/dns/resutil/resolve.go
Normal file
@@ -0,0 +1,197 @@
|
||||
// Package resutil provides shared DNS resolution utilities
|
||||
package resutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GenerateRequestID creates a random 8-character hex string for request tracing.
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
log.Errorf("generate request ID: %v", err)
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// IPsToRRs converts a slice of IP addresses to DNS resource records.
|
||||
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
|
||||
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
|
||||
var result []dns.RR
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.Is6() {
|
||||
result = append(result, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
})
|
||||
} else {
|
||||
result = append(result, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
|
||||
// Returns empty string for unsupported types.
|
||||
func NetworkForQtype(qtype uint16) string {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
return "ip4"
|
||||
case dns.TypeAAAA:
|
||||
return "ip6"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
// chainedWriter is implemented by ResponseWriters that carry request metadata
|
||||
type chainedWriter interface {
|
||||
RequestID() string
|
||||
SetMeta(key, value string)
|
||||
}
|
||||
|
||||
// GetRequestID extracts a request ID from the ResponseWriter if available,
|
||||
// otherwise generates a new one.
|
||||
func GetRequestID(w dns.ResponseWriter) string {
|
||||
if cw, ok := w.(chainedWriter); ok {
|
||||
if id := cw.RequestID(); id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return GenerateRequestID()
|
||||
}
|
||||
|
||||
// SetMeta sets metadata on the ResponseWriter if it supports it.
|
||||
func SetMeta(w dns.ResponseWriter, key, value string) {
|
||||
if cw, ok := w.(chainedWriter); ok {
|
||||
cw.SetMeta(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// LookupResult contains the result of an external DNS lookup
|
||||
type LookupResult struct {
|
||||
IPs []netip.Addr
|
||||
Rcode int
|
||||
Err error // Original error for caller's logging needs
|
||||
}
|
||||
|
||||
// LookupIP performs a DNS lookup and determines the appropriate rcode.
|
||||
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
|
||||
ips, err := r.LookupNetIP(ctx, network, host)
|
||||
if err != nil {
|
||||
return LookupResult{
|
||||
Rcode: getRcodeForError(ctx, r, host, qtype, err),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||
for i, ip := range ips {
|
||||
ips[i] = ip.Unmap()
|
||||
}
|
||||
|
||||
return LookupResult{
|
||||
IPs: ips,
|
||||
Rcode: dns.RcodeSuccess,
|
||||
}
|
||||
}
|
||||
|
||||
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
return dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
if dnsErr.IsNotFound {
|
||||
return getRcodeForNotFound(ctx, r, host, qtype)
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
|
||||
// (domain exists but no records of requested type) by checking the opposite record type.
|
||||
//
|
||||
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
|
||||
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
|
||||
// are not queried by musl and don't need this handling.
|
||||
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
|
||||
// Try querying for a different record type to see if the domain exists
|
||||
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||
// This helps distinguish between NXDOMAIN and NODATA.
|
||||
var alternativeNetwork string
|
||||
switch originalQtype {
|
||||
case dns.TypeAAAA:
|
||||
alternativeNetwork = "ip4"
|
||||
case dns.TypeA:
|
||||
alternativeNetwork = "ip6"
|
||||
default:
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
|
||||
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||
// Alternative query also returned not found - domain truly doesn't exist
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// Alternative query succeeded - domain exists but has no records of this type
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// FormatAnswers formats DNS resource records for logging.
|
||||
func FormatAnswers(answers []dns.RR) string {
|
||||
if len(answers) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(answers))
|
||||
for _, rr := range answers {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
parts = append(parts, r.A.String())
|
||||
case *dns.AAAA:
|
||||
parts = append(parts, r.AAAA.String())
|
||||
case *dns.CNAME:
|
||||
parts = append(parts, "CNAME:"+r.Target)
|
||||
case *dns.PTR:
|
||||
parts = append(parts, "PTR:"+r.Ptr)
|
||||
default:
|
||||
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
|
||||
}
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
@@ -498,8 +498,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
|
||||
// register local records
|
||||
s.localResolver.Update(localRecords)
|
||||
s.localResolver.Update(localZones)
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
@@ -632,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.wgInterface,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
@@ -659,9 +656,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
var localRecords []nbdns.SimpleRecord
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
@@ -675,17 +672,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
priority: PriorityLocal,
|
||||
})
|
||||
|
||||
// zone records contain the fqdn, so we can just flatten them
|
||||
var localRecords []nbdns.SimpleRecord
|
||||
for _, record := range customZone.Records {
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
// zone records contain the fqdn, so we can just flatten them
|
||||
localRecords = append(localRecords, record)
|
||||
}
|
||||
customZone.Records = localRecords
|
||||
zones = append(zones, customZone)
|
||||
}
|
||||
|
||||
return muxUpdates, localRecords, nil
|
||||
return muxUpdates, zones, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||
@@ -741,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.wgInterface,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
domainGroup.domain,
|
||||
@@ -924,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.wgInterface,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
||||
return configurer.WGStats{}, nil
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
@@ -128,7 +133,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalRecords []nbdns.SimpleRecord
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
@@ -180,8 +185,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
@@ -221,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -251,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -273,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -299,8 +304,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
@@ -315,8 +320,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
@@ -385,7 +390,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
@@ -510,8 +515,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@@ -2048,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
|
||||
|
||||
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
// Test that priority constants are ordered correctly
|
||||
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
||||
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
|
||||
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
@@ -19,8 +18,10 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -113,10 +114,7 @@ func (u *upstreamResolverBase) Stop() {
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
requestID := GenerateRequestID()
|
||||
logger := log.WithField("request_id", requestID)
|
||||
|
||||
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
|
||||
u.prepareRequest(r)
|
||||
|
||||
@@ -202,11 +200,18 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
rm.MsgHdr.Zero = false
|
||||
|
||||
if err := w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -414,16 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
_, err := rand.Read(bytes)
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate request ID: %v", err)
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
|
||||
conn, err := nsNet.DialContext(ctx, network, upstream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("with %s: %w", network, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close DNS connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
}
|
||||
|
||||
reply, err := dnsConn.ReadMsg()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read %s message: %w", network, err)
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
|
||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||
func FormatPeerStatus(peerState *peer.State) string {
|
||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||
|
||||
@@ -23,9 +23,7 @@ type upstreamResolver struct {
|
||||
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
_ WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
|
||||
@@ -5,22 +5,23 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
*upstreamResolverBase
|
||||
nsNet *netstack.Net
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@@ -28,12 +29,23 @@ func newUpstreamResolver(
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
nsNet: wgIface.GetNet(),
|
||||
}
|
||||
upstreamResolverBase.upstreamClient = nonIOS
|
||||
return nonIOS, nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
|
||||
// Similar to iOS logic, we should determine if the DNS server is reachable directly
|
||||
// or needs to go through the tunnel, and only use netstack when necessary.
|
||||
// For now, only use netstack on JS platform where direct access is not possible.
|
||||
if u.nsNet != nil && runtime.GOOS == "js" {
|
||||
start := time.Now()
|
||||
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
|
||||
return reply, time.Since(start), err
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
|
||||
@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
interfaceName string,
|
||||
ip netip.Addr,
|
||||
net netip.Prefix,
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@@ -37,9 +35,9 @@ func newUpstreamResolver(
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: ip,
|
||||
lNet: net,
|
||||
interfaceName: interfaceName,
|
||||
lIP: wgIface.Address().IP,
|
||||
lNet: wgIface.Address().Network,
|
||||
interfaceName: wgIface.Name(),
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
|
||||
@@ -2,13 +2,17 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
)
|
||||
|
||||
@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
|
||||
// Convert test servers to netip.AddrPort
|
||||
var servers []netip.AddrPort
|
||||
for _, server := range testCase.InputServers {
|
||||
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockNetstackProvider struct{}
|
||||
|
||||
func (m *mockNetstackProvider) Name() string { return "mock" }
|
||||
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
|
||||
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
|
||||
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
|
||||
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
|
||||
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
|
||||
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
|
||||
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type mockUpstreamResolver struct {
|
||||
r *dns.Msg
|
||||
rtt time.Duration
|
||||
|
||||
@@ -5,6 +5,8 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -17,4 +19,5 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -12,5 +14,6 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||
if len(query.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := query.Question[0]
|
||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||
question.Name, question.Qtype, question.Qclass)
|
||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
|
||||
domain := strings.ToLower(question.Name)
|
||||
|
||||
resp := query.SetReply(query)
|
||||
var network string
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
network = "ip4"
|
||||
case dns.TypeAAAA:
|
||||
network = "ip6"
|
||||
default:
|
||||
// TODO: Handle other types
|
||||
|
||||
network := resutil.NetworkForQtype(question.Qtype)
|
||||
if network == "" {
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
if mostSpecificResId == "" {
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||
if err != nil {
|
||||
f.handleDNSError(ctx, w, question, resp, domain, err)
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||
for i, ip := range ips {
|
||||
ips[i] = ip.Unmap()
|
||||
}
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
||||
f.cache.set(domain, question.Qtype, result.IPs)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
resp := f.handleDNSQuery(w, query)
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
@@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
resp := f.handleDNSQuery(w, query)
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||
@@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
||||
}
|
||||
}
|
||||
|
||||
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
|
||||
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
|
||||
//
|
||||
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
|
||||
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
|
||||
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
|
||||
// only handles A/AAAA queries and returns NOTIMP for other types.
|
||||
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
|
||||
// Try querying for a different record type to see if the domain exists
|
||||
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||
// This helps distinguish between NXDOMAIN and NODATA.
|
||||
var alternativeNetwork string
|
||||
switch originalQtype {
|
||||
case dns.TypeAAAA:
|
||||
alternativeNetwork = "ip4"
|
||||
case dns.TypeA:
|
||||
alternativeNetwork = "ip6"
|
||||
default:
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||
// Alternative query also returned not found - domain truly doesn't exist
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
return
|
||||
}
|
||||
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
return
|
||||
}
|
||||
|
||||
// Alternative query succeeded - domain exists but has no records of this type
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||
func (f *DNSForwarder) handleDNSError(
|
||||
ctx context.Context,
|
||||
logger *log.Entry,
|
||||
w dns.ResponseWriter,
|
||||
question dns.Question,
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
err error,
|
||||
result resutil.LookupResult,
|
||||
) {
|
||||
// Default to SERVFAIL; override below when appropriate.
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
|
||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
resp.Rcode = result.Rcode
|
||||
|
||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
// NotFound: cache negative result and respond
|
||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Upstream failed but we might have a cached answer—serve it if present.
|
||||
if ips, ok := f.cache.get(domain, qType); ok {
|
||||
if len(ips) > 0 {
|
||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cached negative result - re-verify NXDOMAIN vs NODATA
|
||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||
resp.Rcode = verifyResult.Rcode
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No cache. Log with or without the server field for more context.
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||
// No cache or verification failed. Log with or without the server field for more context.
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
|
||||
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
||||
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
||||
for _, ip := range ips {
|
||||
var respRecord dns.RR
|
||||
if ip.Is6() {
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
} else {
|
||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||
rr := dns.A{
|
||||
A: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
}
|
||||
resp.Answer = append(resp.Answer, respRecord)
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
@@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
||||
|
||||
// Verify response
|
||||
if tt.shouldResolve {
|
||||
@@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// Verify response contains all IPs
|
||||
require.NotNil(t, resp)
|
||||
@@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_ = forwarder.handleDNSQuery(mockWriter, query)
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// Check the response written to the writer
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
@@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
@@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
@@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
@@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||
if resp != nil && writtenResp == nil {
|
||||
@@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
assert.Nil(t, resp, "Should return nil for empty query")
|
||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||
|
||||
@@ -1251,11 +1251,16 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
||||
ForwarderPort: forwarderPort,
|
||||
}
|
||||
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
protoZones := protoDNSConfig.GetCustomZones()
|
||||
// Treat single zone as authoritative for backward compatibility with old servers
|
||||
// that only send the peer FQDN zone without setting field 4.
|
||||
singleZoneCompat := len(protoZones) == 1
|
||||
|
||||
for _, zone := range protoZones {
|
||||
dnsZone := nbdns.CustomZone{
|
||||
Domain: zone.GetDomain(),
|
||||
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
dnsRecord := nbdns.SimpleRecord{
|
||||
@@ -1743,22 +1748,26 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
// Skip STUN/TURN probing for JS/WASM as it's not available
|
||||
relayHealthy := true
|
||||
for _, res := range results {
|
||||
if res.Err != nil {
|
||||
relayHealthy = false
|
||||
break
|
||||
if runtime.GOOS != "js" {
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
for _, res := range results {
|
||||
if res.Err != nil {
|
||||
relayHealthy = false
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||
}
|
||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||
|
||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||
|
||||
@@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
|
||||
log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
|
||||
|
||||
jwtConfig := &sshserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audience: protoJWT.GetAudience(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
}
|
||||
|
||||
@@ -669,10 +669,17 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -158,6 +159,7 @@ type FullStatus struct {
|
||||
NSGroupStates []NSGroupState
|
||||
NumOfForwardingRules int
|
||||
LazyConnectionEnabled bool
|
||||
Events []*proto.SystemEvent
|
||||
}
|
||||
|
||||
type StatusChangeSubscription struct {
|
||||
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
}
|
||||
|
||||
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
|
||||
fullStatus.Events = d.GetEventHistory()
|
||||
return fullStatus
|
||||
}
|
||||
|
||||
@@ -1181,3 +1184,97 @@ type EventSubscription struct {
|
||||
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
|
||||
return s.events
|
||||
}
|
||||
|
||||
// ToProto converts FullStatus to proto.FullStatus.
|
||||
func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
|
||||
if err := fs.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fs.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
|
||||
if err := fs.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
|
||||
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
|
||||
|
||||
for _, peerState := range fs.Peers {
|
||||
networks := maps.Keys(peerState.GetRoutes())
|
||||
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: networks,
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fs.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fs.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
pbFullStatus.Events = fs.Events
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
@@ -17,12 +17,13 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -37,11 +38,6 @@ type internalDNATer interface {
|
||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||
}
|
||||
|
||||
type wgInterface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
type DnsInterceptor struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
@@ -51,7 +47,7 @@ type DnsInterceptor struct {
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
interceptedDomains domainMap
|
||||
wgInterface wgInterface
|
||||
wgInterface iface.WGIface
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
@@ -219,14 +215,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
|
||||
// ServeDNS implements the dns.Handler interface
|
||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
requestID := nbdns.GenerateRequestID()
|
||||
logger := log.WithField("request_id", requestID)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
// pass if non A/AAAA query
|
||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||
@@ -249,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
if err != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if r.Extra == nil {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
@@ -263,32 +253,15 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
elapsed := time.Since(startTime)
|
||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||
} else {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
}
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
|
||||
if reply == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
|
||||
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
resutil.SetMeta(w, "peer", peerKey)
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
if err := d.writeMsg(w, reply, logger); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -324,11 +297,15 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
|
||||
return peerAllowedIP, nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
|
||||
if r == nil {
|
||||
return fmt.Errorf("received nil DNS message")
|
||||
}
|
||||
|
||||
// Clear Zero bit from peer responses to prevent external sources from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
r.MsgHdr.Zero = false
|
||||
|
||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||
origPattern := ""
|
||||
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
||||
@@ -350,14 +327,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
@@ -370,11 +347,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
|
||||
logger.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -386,22 +363,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
}
|
||||
|
||||
// logPrefixChanges handles the logging for prefix changes
|
||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
@@ -418,9 +395,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||
dnatMappings[fakeIP] = realIP
|
||||
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||
} else {
|
||||
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -432,7 +409,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
}
|
||||
}
|
||||
|
||||
d.addDNATMappings(dnatMappings)
|
||||
d.addDNATMappings(dnatMappings, logger)
|
||||
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
@@ -448,7 +425,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
}
|
||||
}
|
||||
|
||||
d.removeDNATMappings(toRemove)
|
||||
d.removeDNATMappings(toRemove, logger)
|
||||
}
|
||||
|
||||
// Update domain prefixes using resolved domain as key - store real IPs
|
||||
@@ -463,14 +440,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
// Store real IPs for status (user-facing), not fake IPs
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||
|
||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
||||
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
||||
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
|
||||
if len(realPrefixes) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -484,9 +461,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
} else {
|
||||
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -502,7 +479,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||
}
|
||||
|
||||
// addDNATMappings adds DNAT mappings to the firewall
|
||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -514,9 +491,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||
|
||||
for fakeIP, realIP := range mappings {
|
||||
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||
} else {
|
||||
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -528,12 +505,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||
}
|
||||
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
d.removeDNATMappings(prefixes)
|
||||
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
|
||||
}
|
||||
}
|
||||
|
||||
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
@@ -549,7 +526,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.A = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
|
||||
case *dns.AAAA:
|
||||
@@ -560,7 +537,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.AAAA = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
|
||||
return
|
||||
}
|
||||
|
||||
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
|
||||
// Returns the DNS reply on success, or nil on error (error responses are written internally).
|
||||
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
|
||||
startTime := time.Now()
|
||||
|
||||
nsNet := d.wgInterface.GetNet()
|
||||
var reply *dns.Msg
|
||||
var err error
|
||||
|
||||
if nsNet != nil {
|
||||
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
|
||||
} else {
|
||||
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
if clientErr != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
|
||||
return nil
|
||||
}
|
||||
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return reply
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
elapsed := time.Since(startTime)
|
||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||
} else {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
}
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
|
||||
if d.statusRecorder == nil {
|
||||
return ""
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
|
||||
log.SetLevel(level)
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("connect client not initialized")
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetLogLevel(level)
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager == nil {
|
||||
return nil, fmt.Errorf("firewall manager not initialized")
|
||||
}
|
||||
|
||||
fwManager.SetLogLevel(level)
|
||||
|
||||
log.Infof("Log level set to %s", level.String())
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
|
||||
events := s.statusRecorder.GetEventHistory()
|
||||
return &proto.GetEventsResponse{Events: events}, nil
|
||||
}
|
||||
|
||||
@@ -13,15 +13,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
@@ -1067,11 +1064,9 @@ func (s *Server) Status(
|
||||
if msg.GetFullPeerStatus {
|
||||
s.runProbes(msg.ShouldRunProbes)
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
||||
if err := fullStatus.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
||||
if err := fullStatus.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fullStatus.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fullStatus.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
// sendTerminalNotification sends a terminal notification message
|
||||
// to inform the user that the NetBird connection session has expired.
|
||||
func sendTerminalNotification() error {
|
||||
|
||||
@@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
Audiences: []string{"test-audience"},
|
||||
KeysLocation: "test-keys",
|
||||
}
|
||||
serverConfig := &Config{
|
||||
@@ -202,7 +202,7 @@ func TestJWTDetection(t *testing.T) {
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
@@ -329,7 +329,7 @@ func TestJWTFailClose(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
MaxTokenAge: 3600,
|
||||
}
|
||||
@@ -567,7 +567,7 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
@@ -646,3 +646,108 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTMultipleAudiences tests JWT validation with multiple audiences (dashboard and CLI).
|
||||
func TestJWTMultipleAudiences(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT multiple audiences tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
dashboardAudience = "dashboard-audience"
|
||||
cliAudience = "cli-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
audience string
|
||||
wantAuthOK bool
|
||||
}{
|
||||
{
|
||||
name: "accepts_dashboard_audience",
|
||||
audience: dashboardAudience,
|
||||
wantAuthOK: true,
|
||||
},
|
||||
{
|
||||
name: "accepts_cli_audience",
|
||||
audience: cliAudience,
|
||||
wantAuthOK: true,
|
||||
},
|
||||
{
|
||||
name: "rejects_unknown_audience",
|
||||
audience: "unknown-audience",
|
||||
wantAuthOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{dashboardAudience, cliAudience},
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
testUserHash, err := sshuserhash.HashUserID("test-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
currentUser: {0},
|
||||
},
|
||||
}
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := generateValidJWT(t, privateKey, issuer, tc.audience)
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if tc.wantAuthOK {
|
||||
require.NoError(t, err, "JWT authentication should succeed for audience %s", tc.audience)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
err = session.Shell()
|
||||
require.NoError(t, err, "Shell should work with valid audience")
|
||||
} else {
|
||||
assert.Error(t, err, "JWT authentication should fail for unknown audience")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,9 +176,9 @@ type Server struct {
|
||||
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
Audiences []string
|
||||
}
|
||||
|
||||
// Config contains all SSH server configuration options
|
||||
@@ -427,18 +427,21 @@ func (s *Server) ensureJWTValidator() error {
|
||||
return fmt.Errorf("JWT config not set")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||
if len(config.Audiences) == 0 {
|
||||
return fmt.Errorf("JWT config has no audiences configured")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audiences: %v)", config.Issuer, config.Audiences)
|
||||
validator := jwt.NewValidator(
|
||||
config.Issuer,
|
||||
[]string{config.Audience},
|
||||
config.Audiences,
|
||||
config.KeysLocation,
|
||||
true,
|
||||
)
|
||||
|
||||
// Use custom userIDClaim from authorizer if available
|
||||
extractorOptions := []jwt.ClaimsExtractorOption{
|
||||
jwt.WithAudience(config.Audience),
|
||||
jwt.WithAudience(config.Audiences[0]),
|
||||
}
|
||||
if authorizer.GetUserIDClaim() != "" {
|
||||
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
||||
@@ -475,8 +478,8 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||
if err != nil {
|
||||
if jwtConfig != nil {
|
||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audiences=%v, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audiences, claims["iss"], claims["aud"], err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
|
||||
@@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
|
||||
}
|
||||
}
|
||||
|
||||
func ParseToJSON(overview OutputOverview) (string, error) {
|
||||
jsonBytes, err := json.Marshal(overview)
|
||||
// JSON returns the status overview as a JSON string.
|
||||
func (o *OutputOverview) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
func ParseToYAML(overview OutputOverview) (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(overview)
|
||||
// YAML returns the status overview as a YAML string.
|
||||
func (o *OutputOverview) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
// GeneralSummary returns a general summary of the status overview.
|
||||
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
if o.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
if o.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
var signalConnString string
|
||||
if overview.SignalState.Connected {
|
||||
if o.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
if o.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
interfaceIP := overview.IP
|
||||
if overview.KernelInterface {
|
||||
interfaceIP := o.IP
|
||||
if o.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if overview.IP == "" {
|
||||
} else if o.IP == "" {
|
||||
interfaceTypeString = "N/A"
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
for _, relay := range o.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
|
||||
@@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(overview.Networks) > 0 {
|
||||
sort.Strings(overview.Networks)
|
||||
networks = strings.Join(overview.Networks, ", ")
|
||||
if len(o.Networks) > 0 {
|
||||
sort.Strings(o.Networks)
|
||||
networks = strings.Join(o.Networks, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
for _, nsServerGroup := range o.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
@@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
if o.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if overview.RosenpassPermissive {
|
||||
if o.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
lazyConnectionEnabledStatus := "false"
|
||||
if overview.LazyConnectionEnabled {
|
||||
if o.LazyConnectionEnabled {
|
||||
lazyConnectionEnabledStatus = "true"
|
||||
}
|
||||
|
||||
sshServerStatus := "Disabled"
|
||||
if overview.SSHServerState.Enabled {
|
||||
sessionCount := len(overview.SSHServerState.Sessions)
|
||||
if o.SSHServerState.Enabled {
|
||||
sessionCount := len(o.SSHServerState.Sessions)
|
||||
if sessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if sessionCount > 1 {
|
||||
@@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
}
|
||||
|
||||
if showSSHSessions && sessionCount > 0 {
|
||||
for _, session := range overview.SSHServerState.Sessions {
|
||||
for _, session := range o.SSHServerState.Sessions {
|
||||
var sessionDisplay string
|
||||
if session.JWTUsername != "" {
|
||||
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
|
||||
@@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
@@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
"Forwarding rules: %d\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
o.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
overview.ProfileName,
|
||||
o.ProfileName,
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
domain.Domain(overview.FQDN).SafeString(),
|
||||
domain.Domain(o.FQDN).SafeString(),
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
networks,
|
||||
overview.NumberOfForwardingRules,
|
||||
o.NumberOfForwardingRules,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
func ParseToFullDetailSummary(overview OutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
parsedEventsString := parseEvents(overview.Events)
|
||||
summary := ParseGeneralSummary(overview, true, true, true, true)
|
||||
// FullDetailSummary returns a full detailed summary with peer details and events.
|
||||
func (o *OutputOverview) FullDetailSummary() string {
|
||||
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
|
||||
parsedEventsString := parseEvents(o.Events)
|
||||
summary := o.GeneralSummary(true, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
|
||||
@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
jsonString, _ := ParseToJSON(overview)
|
||||
jsonString, _ := overview.JSON()
|
||||
|
||||
//@formatter:off
|
||||
expectedJSONString := `
|
||||
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := ParseToYAML(overview)
|
||||
yaml, _ := overview.YAML()
|
||||
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := ParseToFullDetailSummary(overview)
|
||||
detail := overview.FullDetailSummary()
|
||||
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
|
||||
shortVersion := overview.GeneralSummary(false, false, false, false)
|
||||
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
|
||||
22
client/system/disk_encryption.go
Normal file
22
client/system/disk_encryption.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package system
|
||||
|
||||
// DiskEncryptionVolume represents encryption status of a single volume.
|
||||
type DiskEncryptionVolume struct {
|
||||
Path string
|
||||
Encrypted bool
|
||||
}
|
||||
|
||||
// DiskEncryptionInfo holds disk encryption detection results.
|
||||
type DiskEncryptionInfo struct {
|
||||
Volumes []DiskEncryptionVolume
|
||||
}
|
||||
|
||||
// IsEncrypted returns true if the volume at the given path is encrypted.
|
||||
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
|
||||
for _, v := range d.Volumes {
|
||||
if v.Path == path {
|
||||
return v.Encrypted
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
35
client/system/disk_encryption_darwin.go
Normal file
35
client/system/disk_encryption_darwin.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// detectDiskEncryption detects FileVault encryption status on macOS.
|
||||
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
|
||||
info := DiskEncryptionInfo{}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "fdesetup", "status")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Debugf("execute fdesetup: %v", err)
|
||||
return info
|
||||
}
|
||||
|
||||
encrypted := strings.Contains(string(output), "FileVault is On")
|
||||
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
|
||||
Path: "/",
|
||||
Encrypted: encrypted,
|
||||
})
|
||||
|
||||
return info
|
||||
}
|
||||
98
client/system/disk_encryption_linux.go
Normal file
98
client/system/disk_encryption_linux.go
Normal file
@@ -0,0 +1,98 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// detectDiskEncryption detects LUKS encryption status on Linux by reading sysfs.
|
||||
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
|
||||
info := DiskEncryptionInfo{}
|
||||
|
||||
encryptedDevices := findEncryptedDevices()
|
||||
mountPoints := parseMounts(encryptedDevices)
|
||||
|
||||
info.Volumes = mountPoints
|
||||
return info
|
||||
}
|
||||
|
||||
// findEncryptedDevices scans /sys/block for dm-crypt (LUKS) encrypted devices.
|
||||
func findEncryptedDevices() map[string]bool {
|
||||
encryptedDevices := make(map[string]bool)
|
||||
|
||||
sysBlock := "/sys/block"
|
||||
entries, err := os.ReadDir(sysBlock)
|
||||
if err != nil {
|
||||
log.Debugf("read /sys/block: %v", err)
|
||||
return encryptedDevices
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
dmUuidPath := filepath.Join(sysBlock, entry.Name(), "dm", "uuid")
|
||||
data, err := os.ReadFile(dmUuidPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
uuid := strings.TrimSpace(string(data))
|
||||
if strings.HasPrefix(uuid, "CRYPT-") {
|
||||
dmNamePath := filepath.Join(sysBlock, entry.Name(), "dm", "name")
|
||||
if nameData, err := os.ReadFile(dmNamePath); err == nil {
|
||||
dmName := strings.TrimSpace(string(nameData))
|
||||
encryptedDevices["/dev/mapper/"+dmName] = true
|
||||
}
|
||||
encryptedDevices["/dev/"+entry.Name()] = true
|
||||
}
|
||||
}
|
||||
|
||||
return encryptedDevices
|
||||
}
|
||||
|
||||
// parseMounts reads /proc/mounts and maps devices to mount points with encryption status.
|
||||
func parseMounts(encryptedDevices map[string]bool) []DiskEncryptionVolume {
|
||||
var volumes []DiskEncryptionVolume
|
||||
|
||||
mountsFile, err := os.Open("/proc/mounts")
|
||||
if err != nil {
|
||||
log.Debugf("open /proc/mounts: %v", err)
|
||||
return volumes
|
||||
}
|
||||
defer func() {
|
||||
if err := mountsFile.Close(); err != nil {
|
||||
log.Debugf("close /proc/mounts: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(mountsFile)
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
device, mountPoint := fields[0], fields[1]
|
||||
|
||||
encrypted := encryptedDevices[device]
|
||||
|
||||
if !encrypted && strings.HasPrefix(device, "/dev/mapper/") {
|
||||
for encDev := range encryptedDevices {
|
||||
if device == encDev {
|
||||
encrypted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
volumes = append(volumes, DiskEncryptionVolume{
|
||||
Path: mountPoint,
|
||||
Encrypted: encrypted,
|
||||
})
|
||||
}
|
||||
|
||||
return volumes
|
||||
}
|
||||
10
client/system/disk_encryption_stub.go
Normal file
10
client/system/disk_encryption_stub.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build android || ios || freebsd || js
|
||||
|
||||
package system
|
||||
|
||||
import "context"
|
||||
|
||||
// detectDiskEncryption is a stub for unsupported platforms.
|
||||
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
|
||||
return DiskEncryptionInfo{}
|
||||
}
|
||||
41
client/system/disk_encryption_windows.go
Normal file
41
client/system/disk_encryption_windows.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build windows
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
)
|
||||
|
||||
// Win32EncryptableVolume represents the WMI class for BitLocker status.
|
||||
type Win32EncryptableVolume struct {
|
||||
DriveLetter string
|
||||
ProtectionStatus uint32
|
||||
}
|
||||
|
||||
// detectDiskEncryption detects BitLocker encryption status on Windows via WMI.
|
||||
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
|
||||
info := DiskEncryptionInfo{}
|
||||
|
||||
var volumes []Win32EncryptableVolume
|
||||
query := "SELECT DriveLetter, ProtectionStatus FROM Win32_EncryptableVolume"
|
||||
|
||||
err := wmi.QueryNamespace(query, &volumes, `root\CIMV2\Security\MicrosoftVolumeEncryption`)
|
||||
if err != nil {
|
||||
log.Debugf("query BitLocker status: %v", err)
|
||||
return info
|
||||
}
|
||||
|
||||
for _, vol := range volumes {
|
||||
driveLetter := strings.TrimSuffix(vol.DriveLetter, "\\")
|
||||
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
|
||||
Path: driveLetter,
|
||||
Encrypted: vol.ProtectionStatus == 1,
|
||||
})
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
@@ -59,6 +59,7 @@ type Info struct {
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
Files []File // for posture checks
|
||||
DiskEncryption DiskEncryptionInfo
|
||||
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
|
||||
@@ -44,6 +44,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemSerialNumber: serial(),
|
||||
SystemProductName: productModel(),
|
||||
SystemManufacturer: productManufacturer(),
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
return gio
|
||||
|
||||
@@ -62,6 +62,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
|
||||
@@ -55,6 +55,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
Environment: env,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
sysName := extractOsName(ctx, "sysName")
|
||||
swVersion := extractOsVersion(ctx, "swVersion")
|
||||
|
||||
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
|
||||
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion, DiskEncryption: detectDiskEncryption(ctx)}
|
||||
gio.Hostname = extractDeviceName(ctx, "hostname")
|
||||
gio.NetbirdVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
@@ -15,7 +15,7 @@ func UpdateStaticInfoAsync() {
|
||||
}
|
||||
|
||||
// GetInfo retrieves system information for WASM environment
|
||||
func GetInfo(_ context.Context) *Info {
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
info := &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: runtime.GOARCH,
|
||||
@@ -25,6 +25,7 @@ func GetInfo(_ context.Context) *Info {
|
||||
Hostname: "wasm-client",
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
collectBrowserInfo(info)
|
||||
|
||||
@@ -73,6 +73,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
return gio
|
||||
|
||||
@@ -35,6 +35,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
addrs, err := networkAddresses()
|
||||
|
||||
@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
postUpStatusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
||||
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
preDownStatusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
time.Now().Format(time.RFC3339), params.duration)
|
||||
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
statusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
|
||||
request := &proto.DebugBundleRequest{
|
||||
|
||||
@@ -9,20 +9,29 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
pingBufferSize = 1500
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -113,18 +122,45 @@ func createStopMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// validateSSHArgs validates SSH connection arguments
|
||||
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
|
||||
if len(args) < 2 {
|
||||
return "", 0, "", js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return "", 0, "", js.ValueOf("host parameter must be a string")
|
||||
}
|
||||
if args[1].Type() != js.TypeNumber {
|
||||
return "", 0, "", js.ValueOf("port parameter must be a number")
|
||||
}
|
||||
|
||||
host = args[0].String()
|
||||
port = args[1].Int()
|
||||
username = "root"
|
||||
|
||||
if len(args) > 2 {
|
||||
if args[2].Type() == js.TypeString && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
} else if args[2].Type() != js.TypeString {
|
||||
return "", 0, "", js.ValueOf("username parameter must be a string")
|
||||
}
|
||||
}
|
||||
|
||||
return host, port, username, js.Undefined()
|
||||
}
|
||||
|
||||
// createSSHMethod creates the SSH connection method
|
||||
func createSSHMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
host := args[0].String()
|
||||
port := args[1].Int()
|
||||
username := "root"
|
||||
if len(args) > 2 && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
host, port, username, validationErr := validateSSHArgs(args)
|
||||
if !validationErr.IsUndefined() {
|
||||
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
|
||||
return validationErr
|
||||
}
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(validationErr)
|
||||
})
|
||||
}
|
||||
|
||||
var jwtToken string
|
||||
@@ -154,6 +190,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
func performPing(client *netbird.Client, hostname string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "ping", hostname)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close ping connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
icmpData := make([]byte, 8)
|
||||
icmpData[0] = icmpEchoRequest
|
||||
icmpData[1] = icmpCodeEcho
|
||||
|
||||
if _, err := conn.Write(icmpData); err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, pingBufferSize)
|
||||
if _, err := conn.Read(buf); err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
|
||||
}
|
||||
|
||||
func performPingTCP(client *netbird.Client, hostname string, port int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
|
||||
return
|
||||
}
|
||||
latency := time.Since(start)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close TCP connection: %v", err)
|
||||
}
|
||||
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
|
||||
}
|
||||
|
||||
// createPingMethod creates the ping method
|
||||
func createPingMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: hostname required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
hostname := args[0].String()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPing(client, hostname)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createPingTCPMethod creates the pingtcp method
|
||||
func createPingTCPMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
if args[1].Type() != js.TypeNumber {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("port parameter must be a number"))
|
||||
})
|
||||
}
|
||||
|
||||
hostname := args[0].String()
|
||||
port := args[1].Int()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPingTCP(client, hostname, port)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createProxyRequestMethod creates the proxyRequest method
|
||||
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
@@ -162,6 +302,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
}
|
||||
|
||||
request := args[0]
|
||||
if request.Type() != js.TypeObject {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("request parameter must be an object"))
|
||||
})
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
response, err := http.ProxyRequest(client, request)
|
||||
@@ -181,11 +326,145 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
if args[1].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("port parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
proxy := rdp.NewRDCleanPathProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
// getStatusOverview is a helper to get the status overview
|
||||
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
|
||||
fullStatus, err := client.Status()
|
||||
if err != nil {
|
||||
return nbstatus.OutputOverview{}, err
|
||||
}
|
||||
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
statusResp := &proto.StatusResponse{
|
||||
DaemonVersion: version.NetbirdVersion(),
|
||||
FullStatus: pbFullStatus,
|
||||
}
|
||||
|
||||
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
|
||||
}
|
||||
|
||||
// createStatusMethod creates the status method that returns JSON
|
||||
func createStatusMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
jsonStr, err := overview.JSON()
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
|
||||
resolve.Invoke(jsonObj)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStatusSummaryMethod creates the statusSummary method
|
||||
func createStatusSummaryMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
summary := overview.GeneralSummary(false, false, false, false)
|
||||
js.Global().Get("console").Call("log", summary)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStatusDetailMethod creates the statusDetail method
|
||||
func createStatusDetailMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
detail := overview.FullDetailSummary()
|
||||
js.Global().Get("console").Call("log", detail)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
|
||||
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
syncResp, err := client.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
options := protojson.MarshalOptions{
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
AllowPartial: true,
|
||||
}
|
||||
jsonBytes, err := options.Marshal(syncResp)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
|
||||
resolve.Invoke(jsonObj)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
|
||||
func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: log level required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("log level parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
logLevel := args[0].String()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
if err := client.SetLogLevel(logLevel); err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
|
||||
return
|
||||
}
|
||||
log.Infof("Log level set to: %s", logLevel)
|
||||
resolve.Invoke(js.ValueOf(true))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
@@ -237,17 +516,24 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
|
||||
obj["start"] = createStartMethod(client)
|
||||
obj["stop"] = createStopMethod(client)
|
||||
obj["ping"] = createPingMethod(client)
|
||||
obj["pingtcp"] = createPingTCPMethod(client)
|
||||
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(this js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
|
||||
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
|
||||
@@ -47,8 +47,8 @@ type CustomZone struct {
|
||||
Records []SimpleRecord
|
||||
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
|
||||
SearchDomainDisabled bool
|
||||
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
|
||||
SkipPTRProcess bool
|
||||
// NonAuthoritative marks user-created zones
|
||||
NonAuthoritative bool
|
||||
}
|
||||
|
||||
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
|
||||
|
||||
8
go.mod
8
go.mod
@@ -78,8 +78,8 @@ require (
|
||||
github.com/pion/logging v0.2.4
|
||||
github.com/pion/randutil v0.1.0
|
||||
github.com/pion/stun/v2 v2.0.0
|
||||
github.com/pion/stun/v3 v3.0.0
|
||||
github.com/pion/transport/v3 v3.0.7
|
||||
github.com/pion/stun/v3 v3.1.0
|
||||
github.com/pion/transport/v3 v3.1.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/pkg/sftp v1.13.9
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
@@ -241,7 +241,7 @@ require (
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.7 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
github.com/pion/transport/v2 v2.2.4 // indirect
|
||||
github.com/pion/turn/v4 v4.1.1 // indirect
|
||||
@@ -263,7 +263,7 @@ require (
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/wlynxg/anet v0.0.3 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
|
||||
16
go.sum
16
go.sum
@@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
|
||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
||||
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
||||
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
|
||||
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
|
||||
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
@@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
|
||||
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
|
||||
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
|
||||
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
|
||||
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
|
||||
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
|
||||
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
||||
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
|
||||
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
||||
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
|
||||
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
|
||||
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||
@@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
|
||||
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
|
||||
113
idp/dex/logrus_handler.go
Normal file
113
idp/dex/logrus_handler.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
// LogrusHandler is an slog.Handler that delegates to logrus.
|
||||
// This allows Dex to use the same log format as the rest of NetBird.
|
||||
type LogrusHandler struct {
|
||||
logger *logrus.Logger
|
||||
attrs []slog.Attr
|
||||
groups []string
|
||||
}
|
||||
|
||||
// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter.
|
||||
func NewLogrusHandler(level slog.Level) *LogrusHandler {
|
||||
logger := logrus.New()
|
||||
formatter.SetTextFormatter(logger)
|
||||
|
||||
// Map slog level to logrus level
|
||||
switch level {
|
||||
case slog.LevelDebug:
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
case slog.LevelInfo:
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
case slog.LevelWarn:
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
case slog.LevelError:
|
||||
logger.SetLevel(logrus.ErrorLevel)
|
||||
default:
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
}
|
||||
|
||||
return &LogrusHandler{logger: logger}
|
||||
}
|
||||
|
||||
// Enabled reports whether the handler handles records at the given level.
|
||||
func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
switch level {
|
||||
case slog.LevelDebug:
|
||||
return h.logger.IsLevelEnabled(logrus.DebugLevel)
|
||||
case slog.LevelInfo:
|
||||
return h.logger.IsLevelEnabled(logrus.InfoLevel)
|
||||
case slog.LevelWarn:
|
||||
return h.logger.IsLevelEnabled(logrus.WarnLevel)
|
||||
case slog.LevelError:
|
||||
return h.logger.IsLevelEnabled(logrus.ErrorLevel)
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Handle handles the Record.
|
||||
func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error {
|
||||
fields := make(logrus.Fields)
|
||||
|
||||
// Add pre-set attributes
|
||||
for _, attr := range h.attrs {
|
||||
fields[attr.Key] = attr.Value.Any()
|
||||
}
|
||||
|
||||
// Add record attributes
|
||||
r.Attrs(func(attr slog.Attr) bool {
|
||||
fields[attr.Key] = attr.Value.Any()
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.WithFields(fields)
|
||||
|
||||
switch r.Level {
|
||||
case slog.LevelDebug:
|
||||
entry.Debug(r.Message)
|
||||
case slog.LevelInfo:
|
||||
entry.Info(r.Message)
|
||||
case slog.LevelWarn:
|
||||
entry.Warn(r.Message)
|
||||
case slog.LevelError:
|
||||
entry.Error(r.Message)
|
||||
default:
|
||||
entry.Info(r.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithAttrs returns a new Handler with the given attributes added.
|
||||
func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
|
||||
copy(newAttrs, h.attrs)
|
||||
copy(newAttrs[len(h.attrs):], attrs)
|
||||
return &LogrusHandler{
|
||||
logger: h.logger,
|
||||
attrs: newAttrs,
|
||||
groups: h.groups,
|
||||
}
|
||||
}
|
||||
|
||||
// WithGroup returns a new Handler with the given group appended to the receiver's groups.
|
||||
func (h *LogrusHandler) WithGroup(name string) slog.Handler {
|
||||
newGroups := make([]string, len(h.groups)+1)
|
||||
copy(newGroups, h.groups)
|
||||
newGroups[len(h.groups)] = name
|
||||
return &LogrusHandler{
|
||||
logger: h.logger,
|
||||
attrs: h.attrs,
|
||||
groups: newGroups,
|
||||
}
|
||||
}
|
||||
@@ -130,7 +130,21 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
|
||||
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
|
||||
logLevel := slog.LevelWarn
|
||||
if yamlConfig.Logger.Level != "" {
|
||||
switch strings.ToLower(yamlConfig.Logger.Level) {
|
||||
case "debug":
|
||||
logLevel = slog.LevelDebug
|
||||
case "info":
|
||||
logLevel = slog.LevelInfo
|
||||
case "warn", "warning":
|
||||
logLevel = slog.LevelWarn
|
||||
case "error":
|
||||
logLevel = slog.LevelError
|
||||
}
|
||||
}
|
||||
logger := slog.New(NewLogrusHandler(logLevel))
|
||||
|
||||
stor, err := yamlConfig.Storage.OpenStorage(logger)
|
||||
if err != nil {
|
||||
@@ -778,20 +792,24 @@ func (p *Provider) resolveRedirectURI(redirectURI string) string {
|
||||
// buildOIDCConnectorConfig creates config for OIDC-based connectors
|
||||
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
|
||||
oidcConfig := map[string]interface{}{
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"insecureEnableGroups": true,
|
||||
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
|
||||
"insecureSkipEmailVerified": true,
|
||||
}
|
||||
switch cfg.Type {
|
||||
case "zitadel":
|
||||
oidcConfig["getUserInfo"] = true
|
||||
case "entra":
|
||||
oidcConfig["insecureSkipEmailVerified"] = true
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
case "okta":
|
||||
oidcConfig["insecureSkipEmailVerified"] = true
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -143,7 +143,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Confi
|
||||
applyCommandLineOverrides(loadedConfig)
|
||||
|
||||
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
|
||||
err := applyEmbeddedIdPConfig(loadedConfig)
|
||||
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||
|
||||
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
||||
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
|
||||
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
||||
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
}
|
||||
@@ -190,10 +190,8 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
||||
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
|
||||
userDeleteFromIDPEnabled = true
|
||||
|
||||
// Ensure HttpConfig exists
|
||||
if cfg.HttpConfig == nil {
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
}
|
||||
// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
|
||||
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
|
||||
|
||||
// Set storage defaults based on Datadir
|
||||
if cfg.EmbeddedIdP.Storage.Type == "" {
|
||||
@@ -205,40 +203,22 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
||||
|
||||
issuer := cfg.EmbeddedIdP.Issuer
|
||||
|
||||
// Set AuthIssuer from EmbeddedIdP issuer
|
||||
if cfg.HttpConfig.AuthIssuer == "" {
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
if cfg.HttpConfig != nil {
|
||||
log.WithContext(ctx).Warnf("overriding HttpConfig with EmbeddedIdP config. " +
|
||||
"HttpConfig is ignored when EmbeddedIdP is enabled. Please remove HttpConfig section from the config file")
|
||||
} else {
|
||||
// Ensure HttpConfig exists. We need it for backwards compatibility with the old config format.
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
}
|
||||
|
||||
// Set AuthAudience to the dashboard client ID
|
||||
if cfg.HttpConfig.AuthAudience == "" {
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
}
|
||||
|
||||
// Set CLIAuthAudience to the client app client ID
|
||||
if cfg.HttpConfig.CLIAuthAudience == "" {
|
||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||
}
|
||||
|
||||
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
|
||||
if cfg.HttpConfig.AuthUserIDClaim == "" {
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
}
|
||||
|
||||
// Set AuthKeysLocation to the JWKS endpoint
|
||||
if cfg.HttpConfig.AuthKeysLocation == "" {
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
}
|
||||
|
||||
// Set OIDCConfigEndpoint to the discovery endpoint
|
||||
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
}
|
||||
|
||||
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
|
||||
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
}
|
||||
// Set HttpConfig values from EmbeddedIdP
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -246,7 +226,12 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
||||
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
||||
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
|
||||
if oidcEndpoint == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
|
||||
// skip OIDC config fetching if EmbeddedIdP is enabled as it is unnecessary given it is embedded
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
}
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -14,6 +15,7 @@ type Repository interface {
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
|
||||
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
|
||||
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
13
management/internals/modules/zones/interface.go
Normal file
13
management/internals/modules/zones/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package zones
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
|
||||
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
|
||||
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
|
||||
}
|
||||
161
management/internals/modules/zones/manager/api.go
Normal file
161
management/internals/modules/zones/manager/api.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager zones.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
apiZones := make([]*api.Zone, 0, len(allZones))
|
||||
for _, zone := range allZones {
|
||||
apiZones = append(apiZones, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiZones)
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsZonesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
zone.ID = zoneID
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
229
management/internals/modules/zones/manager/manager.go
Normal file
229
management/internals/modules/zones/manager/manager.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
dnsDomain string
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||
}
|
||||
}
|
||||
if existingZone != nil {
|
||||
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
|
||||
}
|
||||
|
||||
for _, groupID := range zone.DistributionGroups {
|
||||
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.CreateZone(ctx, zone); err != nil {
|
||||
return fmt.Errorf("failed to create zone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
if zone.Domain != updatedZone.Domain {
|
||||
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
|
||||
}
|
||||
|
||||
zone.Name = updatedZone.Name
|
||||
zone.Enabled = updatedZone.Enabled
|
||||
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
|
||||
zone.DistributionGroups = updatedZone.DistributionGroups
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range zone.DistributionGroups {
|
||||
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.UpdateZone(ctx, zone); err != nil {
|
||||
return fmt.Errorf("failed to update zone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get records: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete zone dns records: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteZone(ctx, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete zone: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
|
||||
})
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range eventsToStore {
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
|
||||
if m.dnsDomain != "" && m.dnsDomain == domain {
|
||||
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.DNSDomain != "" && settings.DNSDomain == domain {
|
||||
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
553
management/internals/modules/zones/manager/manager_test.go
Normal file
553
management/internals/modules/zones/manager/manager_test.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testZoneID = "test-zone-id"
|
||||
testGroupID = "test-group-id"
|
||||
testDNSDomain = "netbird.selfhosted"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testStore.SaveAccount(ctx, &types.Account{
|
||||
Id: testAccountID,
|
||||
Groups: map[string]*types.Group{
|
||||
testGroupID: {
|
||||
ID: testGroupID,
|
||||
Name: "Test Group",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
dnsDomain: testDNSDomain,
|
||||
}
|
||||
|
||||
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone1)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
|
||||
err = testStore.CreateZone(ctx, zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, zone1.ID, result[0].ID)
|
||||
assert.Equal(t, zone2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, zone.ID, result.ID)
|
||||
assert.Equal(t, zone.Name, result.Name)
|
||||
assert.Equal(t, zone.Domain, result.Domain)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
assert.Equal(t, testAccountID, result.AccountID)
|
||||
assert.Equal(t, inputZone.Name, result.Name)
|
||||
assert.Equal(t, inputZone.Domain, result.Domain)
|
||||
assert.Equal(t, inputZone.Enabled, result.Enabled)
|
||||
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("invalid group", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{"invalid-group"},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate domain", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "duplicate.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, s.Type())
|
||||
})
|
||||
|
||||
t.Run("peer DNS domain conflict", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
account, err := testStore.GetAccount(ctx, testAccountID)
|
||||
require.NoError(t, err)
|
||||
account.Settings.DNSDomain = "peers.example.com"
|
||||
err = testStore.SaveAccount(ctx, account)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "Test Zone",
|
||||
Domain: "peers.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("default DNS domain conflict", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "Test Zone",
|
||||
Domain: testDNSDomain,
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: existingZone.ID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, existingZone.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneUpdated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, updatedZone.Name, result.Name)
|
||||
assert.Equal(t, updatedZone.Enabled, result.Enabled)
|
||||
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
})
|
||||
|
||||
t.Run("domain change not allowed", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: existingZone.ID,
|
||||
Name: "Test Zone",
|
||||
Domain: "different.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone domain cannot be updated")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: testZoneID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: "non-existent-zone",
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success with records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCallCount := 0
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCallCount++
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
}
|
||||
|
||||
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, storeEventCallCount)
|
||||
|
||||
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
|
||||
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, zoneRecords)
|
||||
})
|
||||
|
||||
t.Run("success without records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, zone.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneDeleted, activityID)
|
||||
}
|
||||
|
||||
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
|
||||
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
13
management/internals/modules/zones/records/interface.go
Normal file
13
management/internals/modules/zones/records/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
|
||||
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
|
||||
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
|
||||
}
|
||||
191
management/internals/modules/zones/records/manager/api.go
Normal file
191
management/internals/modules/zones/records/manager/api.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager records.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
|
||||
for _, record := range allRecords {
|
||||
apiRecords = append(apiRecords, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiRecords)
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
record.ID = recordID
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
236
management/internals/modules/zones/records/manager/manager.go
Normal file
236
management/internals/modules/zones/records/manager/manager.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
|
||||
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
err = validateRecordConflicts(ctx, transaction, zone, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
|
||||
return fmt.Errorf("failed to create dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
var record *records.Record
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get record: %w", err)
|
||||
}
|
||||
|
||||
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
|
||||
|
||||
record.Name = updatedRecord.Name
|
||||
record.Type = updatedRecord.Type
|
||||
record.Content = updatedRecord.Content
|
||||
record.TTL = updatedRecord.TTL
|
||||
|
||||
if hasChanges {
|
||||
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
|
||||
return fmt.Errorf("failed to update dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var record *records.Record
|
||||
var zone *zones.Zone
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRecordConflicts checks for duplicate records and CNAME conflicts
|
||||
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
|
||||
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
|
||||
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
|
||||
}
|
||||
|
||||
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing records: %w", err)
|
||||
}
|
||||
|
||||
for _, existing := range existingRecords {
|
||||
if existing.ID == record.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if existing.Type == record.Type && existing.Content == record.Content {
|
||||
return status.Errorf(status.AlreadyExists, "identical record already exists")
|
||||
}
|
||||
|
||||
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
|
||||
return status.Errorf(status.InvalidArgument,
|
||||
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,573 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testRecordID = "test-record-id"
|
||||
testGroupID = "test-group-id"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testStore.SaveAccount(ctx, &types.Account{
|
||||
Id: testAccountID,
|
||||
Groups: map[string]*types.Group{
|
||||
testGroupID: {
|
||||
ID: testGroupID,
|
||||
Name: "Test Group",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err = testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
}
|
||||
|
||||
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, record1.ID, result[0].ID)
|
||||
assert.Equal(t, record2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.ID, result.ID)
|
||||
assert.Equal(t, record.Name, result.Name)
|
||||
assert.Equal(t, record.Type, result.Type)
|
||||
assert.Equal(t, record.Content, result.Content)
|
||||
assert.Equal(t, record.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success - A record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
assert.Equal(t, testAccountID, result.AccountID)
|
||||
assert.Equal(t, zone.ID, result.ZoneID)
|
||||
assert.Equal(t, inputRecord.Name, result.Name)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
assert.Equal(t, inputRecord.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("success - AAAA record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "ipv6.example.com",
|
||||
Type: records.RecordTypeAAAA,
|
||||
Content: "2001:db8::1",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("success - CNAME record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "www.example.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "example.com",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record name not in zone", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.different.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "does not belong to zone")
|
||||
})
|
||||
|
||||
t.Run("duplicate record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "identical record already exists")
|
||||
})
|
||||
|
||||
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "example.com",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: existingRecord.ID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100", // Changed IP
|
||||
TTL: 600, // Changed TTL
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, existingRecord.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordUpdated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, updatedRecord.Content, result.Content)
|
||||
assert.Equal(t, updatedRecord.TTL, result.TTL)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
})
|
||||
|
||||
t.Run("update only TTL - no validation", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: existingRecord.ID,
|
||||
Name: existingRecord.Name,
|
||||
Type: existingRecord.Type,
|
||||
Content: existingRecord.Content,
|
||||
TTL: 600, // Only TTL changed
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
// Event should be stored
|
||||
}
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, 600, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: testRecordID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: "non-existent-record",
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("update creates duplicate", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: record2.ID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "identical record already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, record.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordDeleted, activityID)
|
||||
}
|
||||
|
||||
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
|
||||
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
129
management/internals/modules/zones/records/record.go
Normal file
129
management/internals/modules/zones/records/record.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type RecordType string
|
||||
|
||||
const (
|
||||
RecordTypeA RecordType = "A"
|
||||
RecordTypeAAAA RecordType = "AAAA"
|
||||
RecordTypeCNAME RecordType = "CNAME"
|
||||
)
|
||||
|
||||
type Record struct {
|
||||
AccountID string `gorm:"index"`
|
||||
ZoneID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
Name string
|
||||
Type RecordType
|
||||
Content string
|
||||
TTL int
|
||||
}
|
||||
|
||||
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
|
||||
return &Record{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
ZoneID: zoneID,
|
||||
Name: name,
|
||||
Type: recordType,
|
||||
Content: content,
|
||||
TTL: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Record) ToAPIResponse() *api.DNSRecord {
|
||||
recordType := api.DNSRecordType(r.Type)
|
||||
return &api.DNSRecord{
|
||||
Id: r.ID,
|
||||
Name: r.Name,
|
||||
Type: recordType,
|
||||
Content: r.Content,
|
||||
Ttl: r.TTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
|
||||
r.Name = req.Name
|
||||
r.Type = RecordType(req.Type)
|
||||
r.Content = req.Content
|
||||
r.TTL = req.Ttl
|
||||
}
|
||||
|
||||
func (r *Record) Validate() error {
|
||||
if r.Name == "" {
|
||||
return errors.New("record name is required")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(r.Name) {
|
||||
return errors.New("invalid record name format")
|
||||
}
|
||||
|
||||
if r.Type == "" {
|
||||
return errors.New("record type is required")
|
||||
}
|
||||
|
||||
switch r.Type {
|
||||
case RecordTypeA:
|
||||
if err := validateIPv4(r.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
if err := validateIPv6(r.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
case RecordTypeCNAME:
|
||||
if !util.IsValidDomain(r.Content) {
|
||||
return errors.New("invalid CNAME record format")
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid record type, must be A, AAAA, or CNAME")
|
||||
}
|
||||
|
||||
if r.TTL < 0 {
|
||||
return errors.New("TTL cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
|
||||
return map[string]any{
|
||||
"name": r.Name,
|
||||
"type": string(r.Type),
|
||||
"content": r.Content,
|
||||
"ttl": r.TTL,
|
||||
"zone_id": zoneID,
|
||||
"zone_name": zoneName,
|
||||
}
|
||||
}
|
||||
|
||||
func validateIPv4(content string) error {
|
||||
if content == "" {
|
||||
return errors.New("A record is required") //nolint:staticcheck
|
||||
}
|
||||
ip := net.ParseIP(content)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateIPv6(content string) error {
|
||||
if content == "" {
|
||||
return errors.New("AAAA record is required")
|
||||
}
|
||||
ip := net.ParseIP(content)
|
||||
if ip == nil || ip.To4() != nil {
|
||||
return errors.New("AAAA record must be a valid IPv6 address")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
89
management/internals/modules/zones/zone.go
Normal file
89
management/internals/modules/zones/zone.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package zones
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type Zone struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string
|
||||
Enabled bool
|
||||
EnableSearchDomain bool
|
||||
DistributionGroups []string `gorm:"serializer:json"`
|
||||
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
|
||||
}
|
||||
|
||||
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
|
||||
return &Zone{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Domain: domain,
|
||||
Enabled: enabled,
|
||||
EnableSearchDomain: enableSearchDomain,
|
||||
DistributionGroups: distributionGroups,
|
||||
}
|
||||
}
|
||||
|
||||
func (z *Zone) ToAPIResponse() *api.Zone {
|
||||
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
|
||||
for _, record := range z.Records {
|
||||
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
|
||||
apiRecords = append(apiRecords, *apiRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.Zone{
|
||||
DistributionGroups: z.DistributionGroups,
|
||||
Domain: z.Domain,
|
||||
EnableSearchDomain: z.EnableSearchDomain,
|
||||
Enabled: z.Enabled,
|
||||
Id: z.ID,
|
||||
Name: z.Name,
|
||||
Records: apiRecords,
|
||||
}
|
||||
}
|
||||
|
||||
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
|
||||
z.Name = req.Name
|
||||
z.Domain = req.Domain
|
||||
z.EnableSearchDomain = req.EnableSearchDomain
|
||||
z.DistributionGroups = req.DistributionGroups
|
||||
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
z.Enabled = enabled
|
||||
}
|
||||
|
||||
func (z *Zone) Validate() error {
|
||||
if z.Name == "" {
|
||||
return errors.New("zone name is required")
|
||||
}
|
||||
if len(z.Name) > 255 {
|
||||
return errors.New("zone name exceeds maximum length of 255 characters")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(z.Domain) {
|
||||
return errors.New("invalid zone domain format")
|
||||
}
|
||||
|
||||
if len(z.DistributionGroups) == 0 {
|
||||
return errors.New("at least one distribution group is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *Zone) EventMeta() map[string]any {
|
||||
return map[string]any{"name": z.Name, "domain": z.Domain}
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -68,7 +68,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
if len(audiences) > 0 {
|
||||
audience = audiences[0] // Use the first client ID as the primary audience
|
||||
}
|
||||
keysLocation = oauthProvider.GetKeysLocation()
|
||||
// Use localhost keys location for internal validation (management has embedded Dex)
|
||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||
signingKeyRefreshEnabled = true
|
||||
issuer = oauthProvider.GetIssuer()
|
||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||
return Create(s, func() zones.Manager {
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RecordsManager() records.Manager {
|
||||
return Create(s, func() records.Manager {
|
||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -374,8 +374,10 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
@@ -432,9 +434,16 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
|
||||
if config.CLIAuthAudience != "" {
|
||||
audience = config.CLIAuthAudience
|
||||
}
|
||||
|
||||
audiences := []string{config.AuthAudience}
|
||||
if config.CLIAuthAudience != "" && config.CLIAuthAudience != config.AuthAudience {
|
||||
audiences = append(audiences, config.CLIAuthAudience)
|
||||
}
|
||||
|
||||
return &proto.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: audiences,
|
||||
KeysLocation: keysLocation,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -148,3 +151,51 @@ func generateTestData(size int) nbdns.Config {
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authAudience string
|
||||
cliAuthAudience string
|
||||
expectedAudiences []string
|
||||
expectedAudience string
|
||||
}{
|
||||
{
|
||||
name: "only_auth_audience",
|
||||
authAudience: "dashboard-aud",
|
||||
cliAuthAudience: "",
|
||||
expectedAudiences: []string{"dashboard-aud"},
|
||||
expectedAudience: "dashboard-aud",
|
||||
},
|
||||
{
|
||||
name: "both_audiences_different",
|
||||
authAudience: "dashboard-aud",
|
||||
cliAuthAudience: "cli-aud",
|
||||
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
|
||||
expectedAudience: "cli-aud",
|
||||
},
|
||||
{
|
||||
name: "both_audiences_same",
|
||||
authAudience: "same-aud",
|
||||
cliAuthAudience: "same-aud",
|
||||
expectedAudiences: []string{"same-aud"},
|
||||
expectedAudience: "same-aud",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := &nbconfig.HttpServerConfig{
|
||||
AuthIssuer: "https://issuer.example.com",
|
||||
AuthAudience: tc.authAudience,
|
||||
CLIAuthAudience: tc.cliAuthAudience,
|
||||
}
|
||||
|
||||
result := buildJWTConfig(config, nil)
|
||||
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
|
||||
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -470,6 +470,16 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
})
|
||||
}
|
||||
|
||||
diskEncryptionVolumes := make([]nbpeer.DiskEncryptionVolume, 0)
|
||||
if meta.GetDiskEncryption() != nil {
|
||||
for _, vol := range meta.GetDiskEncryption().GetVolumes() {
|
||||
diskEncryptionVolumes = append(diskEncryptionVolumes, nbpeer.DiskEncryptionVolume{
|
||||
Path: vol.GetPath(),
|
||||
Encrypted: vol.GetEncrypted(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nbpeer.PeerSystemMeta{
|
||||
Hostname: meta.GetHostname(),
|
||||
GoOS: meta.GetGoOS(),
|
||||
@@ -501,6 +511,9 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
||||
},
|
||||
Files: files,
|
||||
DiskEncryption: nbpeer.DiskEncryptionInfo{
|
||||
Volumes: diskEncryptionVolumes,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -295,7 +295,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -388,7 +388,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return newSettings, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -402,6 +402,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new
|
||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||
}
|
||||
|
||||
if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" {
|
||||
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||
}
|
||||
}
|
||||
if existingZone != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain)
|
||||
}
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -397,7 +398,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -1676,7 +1677,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
@@ -1686,7 +1687,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
@@ -2095,6 +2096,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
ctx := context.Background()
|
||||
err = manager.Store.CreateZone(ctx, &zones.Zone{
|
||||
ID: "test-zone-id",
|
||||
AccountID: accountID,
|
||||
Name: "Test Zone",
|
||||
Domain: "custom.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{},
|
||||
})
|
||||
require.NoError(t, err, "unable to create custom DNS zone")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
|
||||
DNSDomain: "custom.example.com",
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone")
|
||||
assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone")
|
||||
}
|
||||
|
||||
func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
|
||||
@@ -187,6 +187,14 @@ const (
|
||||
IdentityProviderUpdated Activity = 94
|
||||
IdentityProviderDeleted Activity = 95
|
||||
|
||||
DNSZoneCreated Activity = 96
|
||||
DNSZoneUpdated Activity = 97
|
||||
DNSZoneDeleted Activity = 98
|
||||
|
||||
DNSRecordCreated Activity = 99
|
||||
DNSRecordUpdated Activity = 100
|
||||
DNSRecordDeleted Activity = 101
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -303,6 +311,14 @@ var activityMap = map[Activity]Code{
|
||||
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
|
||||
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
|
||||
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
|
||||
|
||||
DNSZoneCreated: {"DNS zone created", "dns.zone.create"},
|
||||
DNSZoneUpdated: {"DNS zone updated", "dns.zone.update"},
|
||||
DNSZoneDeleted: {"DNS zone deleted", "dns.zone.delete"},
|
||||
|
||||
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
|
||||
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
|
||||
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
25
management/server/cache/idp.go
vendored
25
management/server/cache/idp.go
vendored
@@ -26,6 +26,8 @@ type UserDataCache interface {
|
||||
Get(ctx context.Context, key string) (*idp.UserData, error)
|
||||
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
|
||||
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
|
||||
}
|
||||
|
||||
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
|
||||
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
|
||||
return u.cache.Delete(ctx, key)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
|
||||
var users []*idp.UserData
|
||||
v, err := u.cache.Get(ctx, key, &users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case []*idp.UserData:
|
||||
return v, nil
|
||||
case *[]*idp.UserData:
|
||||
return *v, nil
|
||||
case []byte:
|
||||
return unmarshalUserData(v)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected type: %T", v)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
|
||||
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
|
||||
}
|
||||
|
||||
// NewUserDataCache creates a new UserDataCacheImpl object.
|
||||
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
|
||||
simpleCache := cache.New[any](store)
|
||||
|
||||
@@ -15,7 +15,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
@@ -56,7 +59,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -138,6 +141,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -298,8 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
ua := strings.ToLower(r.Header.Get("User-Agent"))
|
||||
return strings.Contains(ua, "terraform")
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
|
||||
@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||
})
|
||||
|
||||
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test various Terraform user agent formats
|
||||
terraformUserAgents := []string{
|
||||
"Terraform/1.5.0",
|
||||
"terraform/1.0.0",
|
||||
"Terraform-Provider/2.0.0",
|
||||
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
|
||||
}
|
||||
|
||||
for _, userAgent := range terraformUserAgents {
|
||||
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
|
||||
successCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
@@ -93,8 +95,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/rs/xid"
|
||||
@@ -17,6 +23,69 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// oidcProviderJSON represents the OpenID Connect discovery document
|
||||
type oidcProviderJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
}
|
||||
|
||||
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
|
||||
// and verifying that the returned issuer matches the configured one.
|
||||
func validateOIDCIssuer(ctx context.Context, issuer string) error {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
|
||||
}
|
||||
|
||||
var p oidcProviderJSON
|
||||
if err := json.Unmarshal(body, &p); err != nil {
|
||||
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if p.Issuer != issuer {
|
||||
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateIdentityProviderConfig validates the identity provider configuration including
|
||||
// basic validation and OIDC issuer verification.
|
||||
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
// Validate the issuer by calling the OIDC discovery endpoint
|
||||
if idpConfig.Issuer != "" {
|
||||
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIdentityProviders returns all identity providers for an account
|
||||
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
@@ -82,8 +151,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
@@ -119,8 +188,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
|
||||
@@ -2,6 +2,10 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -200,3 +204,109 @@ func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "name is required")
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupServer func() *httptest.Server
|
||||
expectedErr error
|
||||
expectedErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "issuer mismatch",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := oidcProviderJSON{Issuer: "https://different-issuer.com"}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerMismatch,
|
||||
expectedErrMsg: "does not match",
|
||||
},
|
||||
{
|
||||
name: "server returns non-200 status",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = w.Write([]byte("not found"))
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||
expectedErrMsg: "404",
|
||||
},
|
||||
{
|
||||
name: "server returns invalid JSON",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("invalid json"))
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||
expectedErrMsg: "failed to decode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, tt.expectedErr), "expected error %v, got %v", tt.expectedErr, err)
|
||||
if tt.expectedErrMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.expectedErrMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_Success(t *testing.T) {
|
||||
// Create a server that returns its own URL as the issuer
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
resp := oidcProviderJSON{Issuer: server.URL}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) {
|
||||
// Use a URL that will definitely fail to connect
|
||||
err := validateOIDCIssuer(context.Background(), "http://localhost:59999")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable))
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) {
|
||||
// Test that trailing slashes are handled correctly
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
// Return issuer without trailing slash
|
||||
resp := oidcProviderJSON{Issuer: server.URL}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pass issuer with trailing slash
|
||||
err := validateOIDCIssuer(context.Background(), server.URL+"/")
|
||||
// This should fail because the issuer returned doesn't have trailing slash
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch))
|
||||
}
|
||||
|
||||
@@ -135,10 +135,11 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
httpClient := &http.Client{
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.AuthIssuer == "" {
|
||||
|
||||
@@ -48,16 +48,15 @@ type AuthentikCredentials struct {
|
||||
}
|
||||
|
||||
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
||||
func NewAuthentikManager(config AuthentikClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||
func NewAuthentikManager(config AuthentikClientConfig, appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
|
||||
@@ -57,10 +57,11 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
httpClient := &http.Client{
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/google/uuid"
|
||||
@@ -19,7 +20,7 @@ const (
|
||||
staticClientCLI = "netbird-cli"
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email offline_access"
|
||||
defaultScopes = "openid profile email"
|
||||
defaultUserIDClaim = "sub"
|
||||
)
|
||||
|
||||
@@ -27,8 +28,11 @@ const (
|
||||
type EmbeddedIdPConfig struct {
|
||||
// Enabled indicates whether the embedded IDP is enabled
|
||||
Enabled bool
|
||||
// Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2")
|
||||
// Issuer is the OIDC issuer URL (e.g., "https://management.netbird.io/oauth2")
|
||||
Issuer string
|
||||
// LocalAddress is the management server's local listen address (e.g., ":8080" or "localhost:8080")
|
||||
// Used for internal JWT validation to avoid external network calls
|
||||
LocalAddress string
|
||||
// Storage configuration for the IdP database
|
||||
Storage EmbeddedStorageConfig
|
||||
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
|
||||
@@ -146,7 +150,12 @@ var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil)
|
||||
// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows.
|
||||
type OAuthConfigProvider interface {
|
||||
GetIssuer() string
|
||||
// GetKeysLocation returns the public JWKS endpoint URL (uses external issuer URL)
|
||||
GetKeysLocation() string
|
||||
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal use.
|
||||
// Management server has embedded Dex and can validate tokens via localhost,
|
||||
// avoiding external network calls and DNS resolution issues during startup.
|
||||
GetLocalKeysLocation() string
|
||||
GetClientIDs() []string
|
||||
GetUserIDClaim() string
|
||||
GetTokenEndpoint() string
|
||||
@@ -500,6 +509,22 @@ func (m *EmbeddedIdPManager) GetKeysLocation() string {
|
||||
return m.provider.GetKeysLocation()
|
||||
}
|
||||
|
||||
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal token validation.
|
||||
// Uses the LocalAddress from config (management server's listen address) since embedded Dex
|
||||
// is served by the management HTTP server, not a standalone Dex server.
|
||||
func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
|
||||
addr := m.config.LocalAddress
|
||||
if addr == "" {
|
||||
return ""
|
||||
}
|
||||
// Construct localhost URL from listen address
|
||||
// addr is in format ":port" or "host:port" or "localhost:port"
|
||||
if strings.HasPrefix(addr, ":") {
|
||||
return fmt.Sprintf("http://localhost%s/oauth2/keys", addr)
|
||||
}
|
||||
return fmt.Sprintf("http://%s/oauth2/keys", addr)
|
||||
}
|
||||
|
||||
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
|
||||
func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
||||
return []string{staticClientDashboard, staticClientCLI}
|
||||
|
||||
@@ -247,3 +247,61 @@ func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
|
||||
t.Logf(" Raw UUID: %s", rawUserID)
|
||||
t.Logf(" Connector: %s", connectorID)
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localAddress string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "localhost with port",
|
||||
localAddress: "localhost:8080",
|
||||
expected: "http://localhost:8080/oauth2/keys",
|
||||
},
|
||||
{
|
||||
name: "localhost with https port",
|
||||
localAddress: "localhost:443",
|
||||
expected: "http://localhost:443/oauth2/keys",
|
||||
},
|
||||
{
|
||||
name: "port only format",
|
||||
localAddress: ":8080",
|
||||
expected: "http://localhost:8080/oauth2/keys",
|
||||
},
|
||||
{
|
||||
name: "empty address",
|
||||
localAddress: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAddress: tt.localAddress,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex-"+tt.name+".db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
result := manager.GetLocalKeysLocation()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2/google"
|
||||
@@ -49,9 +48,10 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.CustomerID == "" {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||
|
||||
@@ -46,9 +45,10 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.APIToken == "" {
|
||||
|
||||
@@ -63,9 +63,10 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
||||
@@ -45,7 +44,7 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
@@ -88,9 +87,10 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ManagementEndpoint == "" {
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,3 +71,24 @@ func baseURL(rawURL string) string {
|
||||
|
||||
return parsedURL.Scheme + "://" + parsedURL.Host
|
||||
}
|
||||
|
||||
const (
|
||||
// Provides the env variable name for use with idpTimeout function
|
||||
idpTimeoutEnv = "NB_IDP_TIMEOUT"
|
||||
// Sets the defaultTimeout to 10s.
|
||||
defaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// idpTimeout returns a timeout value for the IDP
|
||||
func idpTimeout() time.Duration {
|
||||
timeoutStr, ok := os.LookupEnv(idpTimeoutEnv)
|
||||
if !ok || timeoutStr == "" {
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
timeout, err := time.ParseDuration(timeoutStr)
|
||||
if err != nil {
|
||||
return defaultTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
@@ -164,9 +164,10 @@ func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetri
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
hasPAT := config.PAT != ""
|
||||
|
||||
@@ -95,6 +95,27 @@ type File struct {
|
||||
ProcessIsRunning bool
|
||||
}
|
||||
|
||||
// DiskEncryptionVolume represents encryption status of a volume.
|
||||
type DiskEncryptionVolume struct {
|
||||
Path string
|
||||
Encrypted bool
|
||||
}
|
||||
|
||||
// DiskEncryptionInfo holds encryption info for all volumes.
|
||||
type DiskEncryptionInfo struct {
|
||||
Volumes []DiskEncryptionVolume `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// IsEncrypted returns true if the volume at path is encrypted.
|
||||
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
|
||||
for _, v := range d.Volumes {
|
||||
if v.Path == path {
|
||||
return v.Encrypted
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Flags defines a set of options to control feature behavior
|
||||
type Flags struct {
|
||||
RosenpassEnabled bool
|
||||
@@ -127,9 +148,10 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
SystemSerialNumber string
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment `gorm:"serializer:json"`
|
||||
Flags Flags `gorm:"serializer:json"`
|
||||
Files []File `gorm:"serializer:json"`
|
||||
Environment Environment `gorm:"serializer:json"`
|
||||
Flags Flags `gorm:"serializer:json"`
|
||||
Files []File `gorm:"serializer:json"`
|
||||
DiskEncryption DiskEncryptionInfo `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
@@ -159,6 +181,19 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(p.DiskEncryption.Volumes, func(i, j int) bool {
|
||||
return p.DiskEncryption.Volumes[i].Path < p.DiskEncryption.Volumes[j].Path
|
||||
})
|
||||
sort.Slice(other.DiskEncryption.Volumes, func(i, j int) bool {
|
||||
return other.DiskEncryption.Volumes[i].Path < other.DiskEncryption.Volumes[j].Path
|
||||
})
|
||||
equalDiskEncryption := slices.EqualFunc(p.DiskEncryption.Volumes, other.DiskEncryption.Volumes, func(vol DiskEncryptionVolume, oVol DiskEncryptionVolume) bool {
|
||||
return vol.Path == oVol.Path && vol.Encrypted == oVol.Encrypted
|
||||
})
|
||||
if !equalDiskEncryption {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.Hostname == other.Hostname &&
|
||||
p.GoOS == other.GoOS &&
|
||||
p.Kernel == other.Kernel &&
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user