mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-18 14:49:57 +00:00
Merge branch 'main' into fix/pkg-loss
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -68,6 +70,10 @@ type ConfigInput struct {
|
||||
DisableFirewall *bool
|
||||
|
||||
BlockLANAccess *bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
DNSLabels domain.List
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -93,6 +99,10 @@ type Config struct {
|
||||
|
||||
BlockLANAccess bool
|
||||
|
||||
DisableNotifications bool
|
||||
|
||||
DNSLabels domain.List
|
||||
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
|
||||
@@ -469,6 +479,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableNotifications != nil && *input.DisableNotifications != config.DisableNotifications {
|
||||
if *input.DisableNotifications {
|
||||
log.Infof("disabling notifications")
|
||||
} else {
|
||||
log.Infof("enabling notifications")
|
||||
}
|
||||
config.DisableNotifications = *input.DisableNotifications
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ClientCertKeyPath != "" {
|
||||
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||
updated = true
|
||||
@@ -489,6 +509,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) {
|
||||
log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
|
||||
input.DNSLabels.SafeString(),
|
||||
config.DNSLabels.SafeString())
|
||||
config.DNSLabels = input.DNSLabels
|
||||
updated = true
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -478,7 +478,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
111
client/internal/dns.go
Normal file
111
client/internal/dns.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
ipOctets := strings.Split(ip.String(), ".")
|
||||
slices.Reverse(ipOctets)
|
||||
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||
|
||||
return nbdns.SimpleRecord{
|
||||
Name: rdnsName,
|
||||
Type: int(dns.TypePTR),
|
||||
Class: aRecord.Class,
|
||||
TTL: aRecord.TTL,
|
||||
RData: dns.Fqdn(aRecord.Name),
|
||||
}, true
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
for i := 0; i < octetsToUse; i++ {
|
||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||
}
|
||||
|
||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||
}
|
||||
|
||||
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||
func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.Domain == zoneName {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
if zoneExists(config, zoneName) {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
|
||||
}
|
||||
@@ -9,6 +9,11 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa"
|
||||
ipv6ReverseZone = ".ip6.arpa"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
@@ -94,9 +99,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
MatchOnly: false,
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -131,11 +131,30 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
policyPath := dnsPolicyConfigMatchPath
|
||||
if r.gpo {
|
||||
policyPath = gpoDnsPolicyConfigMatchPath
|
||||
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
||||
}
|
||||
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
||||
}
|
||||
|
||||
if err := refreshGroupPolicy(); err != nil {
|
||||
log.Warnf("failed to refresh group policy: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
|
||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip string) error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
|
||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||
}
|
||||
@@ -162,13 +181,6 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
||||
return fmt.Errorf("set %s: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
if err := refreshGroupPolicy(); err != nil {
|
||||
log.Warnf("failed to refresh group policy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ type registrationMap map[string]struct{}
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map
|
||||
records sync.Map // key: string (domain_class_type), value: []dns.RR
|
||||
}
|
||||
|
||||
func (d *localResolver) MatchSubdomains() bool {
|
||||
@@ -44,11 +44,12 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
|
||||
response := d.lookupRecord(r)
|
||||
if response != nil {
|
||||
replyMessage.Answer = append(replyMessage.Answer, response)
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(r)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
@@ -59,38 +60,65 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
||||
if len(r.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(question.Name)
|
||||
record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype))
|
||||
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
|
||||
|
||||
value, found := d.records.Load(key)
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
return record.(dns.RR)
|
||||
}
|
||||
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
fullRecord, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("register record: %w", err)
|
||||
records, ok := value.([]dns.RR)
|
||||
if !ok {
|
||||
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
fullRecord.Header().Rdlength = record.Len()
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records.Store(key, records)
|
||||
}
|
||||
|
||||
header := fullRecord.Header()
|
||||
d.records.Store(buildRecordKey(header.Name, header.Class, header.Rrtype), fullRecord)
|
||||
|
||||
return nil
|
||||
return records
|
||||
}
|
||||
|
||||
// registerRecord stores a new record by appending it to any existing list
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
|
||||
rr, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
rr.Header().Rdlength = record.Len()
|
||||
header := rr.Header()
|
||||
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
|
||||
|
||||
// load any existing slice of records, then append
|
||||
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
|
||||
records := existing.([]dns.RR)
|
||||
records = append(records, rr)
|
||||
|
||||
// store updated slice
|
||||
d.records.Store(key, records)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// deleteRecord removes *all* records under the recordKey.
|
||||
func (d *localResolver) deleteRecord(recordKey string) {
|
||||
d.records.Delete(dns.Fqdn(recordKey))
|
||||
}
|
||||
|
||||
// buildRecordKey consistently generates a key: name_class_type
|
||||
func buildRecordKey(name string, class, qType uint16) string {
|
||||
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
|
||||
return key
|
||||
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
|
||||
}
|
||||
|
||||
func (d *localResolver) probeAvailability() {}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
resolver := &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
}
|
||||
_ = resolver.registerRecord(testCase.inputRecord)
|
||||
_, _ = resolver.registerRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
|
||||
@@ -393,18 +393,22 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("upstream handler updater: %w", err)
|
||||
}
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
|
||||
// register local records
|
||||
s.updateLocalResolver(localRecordsByDomain)
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
@@ -434,13 +438,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
customZones []nbdns.CustomZone,
|
||||
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
|
||||
|
||||
var muxUpdates []handlerWrapper
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
localRecords := make(map[string][]nbdns.SimpleRecord)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
|
||||
continue
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
@@ -449,16 +457,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
|
||||
// group all records under this domain
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
localRecords[key] = record
|
||||
|
||||
localRecords[key] = append(localRecords[key], record)
|
||||
}
|
||||
}
|
||||
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
@@ -594,7 +606,8 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
|
||||
// remove old records that are no longer present
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
@@ -603,12 +616,18 @@ func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
for _, recs := range update {
|
||||
for _, rec := range recs {
|
||||
// convert the record to a dns.RR and register
|
||||
key, err := s.localResolver.registerRecord(rec)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v",
|
||||
rec.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
|
||||
@@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Fail",
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
@@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
@@ -573,7 +577,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer dnsServer.Stop()
|
||||
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
@@ -10,6 +9,8 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type ServiceViaMemory struct {
|
||||
@@ -27,7 +28,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
||||
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
return s
|
||||
@@ -118,22 +119,3 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
@@ -23,7 +25,7 @@ func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -230,6 +231,14 @@ func (u *upstreamResolverBase) probeAvailability() {
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable(errors.ErrorOrNil())
|
||||
|
||||
u.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_WARNING,
|
||||
proto.SystemEvent_DNS,
|
||||
"All upstream servers failed",
|
||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/pion/ice/v3"
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
@@ -28,7 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
@@ -153,7 +154,7 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
|
||||
@@ -724,7 +725,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if netstack.IsEnabled() {
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
@@ -952,7 +953,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
@@ -1021,7 +1022,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
|
||||
return dnsRoutes
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
@@ -1061,6 +1062,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
}
|
||||
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
|
||||
}
|
||||
|
||||
if len(dnsUpdate.CustomZones) > 0 {
|
||||
addReverseZone(&dnsUpdate, network)
|
||||
}
|
||||
|
||||
return dnsUpdate
|
||||
}
|
||||
|
||||
@@ -1367,7 +1373,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||
return routes, &dnsCfg, nil
|
||||
}
|
||||
|
||||
@@ -1716,6 +1722,37 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
intf := e.wgInterface
|
||||
e.syncMsgMux.Unlock()
|
||||
if intf == nil {
|
||||
return nil, errors.New("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
nsnet := intf.GetNet()
|
||||
if nsnet == nil {
|
||||
return nil, errors.New("failed to get netstack")
|
||||
}
|
||||
return nsnet, nil
|
||||
}
|
||||
|
||||
func (e *Engine) Address() (netip.Addr, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
intf := e.wgInterface
|
||||
e.syncMsgMux.Unlock()
|
||||
if intf == nil {
|
||||
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
addr := e.wgInterface.Address()
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
|
||||
}
|
||||
return ip.Unmap(), nil
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
|
||||
@@ -23,10 +23,11 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
@@ -48,6 +49,8 @@ import (
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -64,6 +67,114 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() device.WGAddress
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
CloseFunc func() error
|
||||
SetFilterFunc func(filter device.PacketFilter) error
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetNetFunc func() *netstack.Net
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Create() error {
|
||||
return m.CreateFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) IsUserspaceBind() bool {
|
||||
return m.IsUserspaceBindFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemovePeer(peerKey string) error {
|
||||
return m.RemovePeerFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.AddAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
return m.SetFilterFunc(filter)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetFilter() device.PacketFilter {
|
||||
return m.GetFilterFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
return m.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||
return m.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", "console")
|
||||
code := m.Run()
|
||||
@@ -245,11 +356,20 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil)
|
||||
|
||||
wgIface := &iface.MockWGIface{
|
||||
wgIface := &MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||
@@ -692,6 +812,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Domain: "0.66.100.in-addr.arpa.",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*mgmtProto.NameServerGroup{
|
||||
{
|
||||
@@ -721,6 +844,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Domain: "0.66.100.in-addr.arpa.",
|
||||
},
|
||||
},
|
||||
expectedNSGroupsLen: 1,
|
||||
expectedNSGroups: []*nbdns.NameServerGroup{
|
||||
@@ -1226,7 +1352,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
8
client/internal/iface.go
Normal file
8
client/internal/iface.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package internal
|
||||
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
}
|
||||
38
client/internal/iface_common.go
Normal file
38
client/internal/iface_common.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type wgIfaceBase interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close() error
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
6
client/internal/iface_windows.go
Normal file
6
client/internal/iface_windows.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package internal
|
||||
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
)
|
||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, err
|
||||
}
|
||||
|
||||
|
||||
@@ -52,17 +52,10 @@ const (
|
||||
connPriorityICEP2P ConnPriority = 3
|
||||
)
|
||||
|
||||
type WgInterface interface {
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(publicKey string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
|
||||
type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
WgInterface WgInterface
|
||||
WgInterface WGIface
|
||||
AllowedIps string
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
17
client/internal/peer/iface.go
Normal file
17
client/internal/peer/iface.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetProxy() wgproxy.Proxy
|
||||
}
|
||||
@@ -7,21 +7,31 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
)
|
||||
|
||||
const eventQueueSize = 10
|
||||
|
||||
type ResolvedDomainInfo struct {
|
||||
Prefixes []netip.Prefix
|
||||
ParentDomain domain.Domain
|
||||
}
|
||||
|
||||
type EventListener interface {
|
||||
OnEvent(event *proto.SystemEvent)
|
||||
}
|
||||
|
||||
// State contains the latest state of a peer
|
||||
type State struct {
|
||||
Mux *sync.RWMutex
|
||||
@@ -157,6 +167,10 @@ type Status struct {
|
||||
peerListChangedForNotification bool
|
||||
|
||||
relayMgr *relayClient.Manager
|
||||
|
||||
eventMux sync.RWMutex
|
||||
eventStreams map[string]chan *proto.SystemEvent
|
||||
eventQueue *EventQueue
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
@@ -164,6 +178,8 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
@@ -806,3 +822,112 @@ func (d *Status) notifyAddressChanged() {
|
||||
func (d *Status) numOfPeers() int {
|
||||
return len(d.peers) + len(d.offlinePeers)
|
||||
}
|
||||
|
||||
// PublishEvent adds an event to the queue and distributes it to all subscribers
|
||||
func (d *Status) PublishEvent(
|
||||
severity proto.SystemEvent_Severity,
|
||||
category proto.SystemEvent_Category,
|
||||
msg string,
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
event := &proto.SystemEvent{
|
||||
Id: uuid.New().String(),
|
||||
Severity: severity,
|
||||
Category: category,
|
||||
Message: msg,
|
||||
UserMessage: userMsg,
|
||||
Metadata: metadata,
|
||||
Timestamp: timestamppb.Now(),
|
||||
}
|
||||
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
|
||||
d.eventQueue.Add(event)
|
||||
|
||||
for _, stream := range d.eventStreams {
|
||||
select {
|
||||
case stream <- event:
|
||||
default:
|
||||
log.Debugf("event stream buffer full, skipping event: %v", event)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("event published: %v", event)
|
||||
}
|
||||
|
||||
// SubscribeToEvents returns a new event subscription
|
||||
func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
stream := make(chan *proto.SystemEvent, 10)
|
||||
d.eventStreams[id] = stream
|
||||
|
||||
return &EventSubscription{
|
||||
id: id,
|
||||
events: stream,
|
||||
}
|
||||
}
|
||||
|
||||
// UnsubscribeFromEvents removes an event subscription
|
||||
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
if sub == nil {
|
||||
return
|
||||
}
|
||||
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
|
||||
if stream, exists := d.eventStreams[sub.id]; exists {
|
||||
close(stream)
|
||||
delete(d.eventStreams, sub.id)
|
||||
}
|
||||
}
|
||||
|
||||
// GetEventHistory returns all events in the queue
|
||||
func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||
return d.eventQueue.GetAll()
|
||||
}
|
||||
|
||||
type EventQueue struct {
|
||||
maxSize int
|
||||
events []*proto.SystemEvent
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewEventQueue(size int) *EventQueue {
|
||||
return &EventQueue{
|
||||
maxSize: size,
|
||||
events: make([]*proto.SystemEvent, 0, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *EventQueue) Add(event *proto.SystemEvent) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
q.events = append(q.events, event)
|
||||
|
||||
if len(q.events) > q.maxSize {
|
||||
q.events = q.events[len(q.events)-q.maxSize:]
|
||||
}
|
||||
}
|
||||
|
||||
func (q *EventQueue) GetAll() []*proto.SystemEvent {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return slices.Clone(q.events)
|
||||
}
|
||||
|
||||
type EventSubscription struct {
|
||||
id string
|
||||
events chan *proto.SystemEvent
|
||||
}
|
||||
|
||||
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
|
||||
return s.events
|
||||
}
|
||||
|
||||
@@ -4,21 +4,22 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
runtime "runtime"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -28,6 +29,15 @@ const (
|
||||
handlerTypeStatic
|
||||
)
|
||||
|
||||
type reason int
|
||||
|
||||
const (
|
||||
reasonUnknown reason = iota
|
||||
reasonRouteUpdate
|
||||
reasonPeerUpdate
|
||||
reasonShutdown
|
||||
)
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
@@ -52,7 +62,7 @@ type clientNetwork struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
@@ -65,7 +75,7 @@ type clientNetwork struct {
|
||||
func newClientNetworkWatcher(
|
||||
ctx context.Context,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface iface.IWGIface,
|
||||
wgInterface iface.WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
@@ -255,7 +265,7 @@ func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error {
|
||||
if c.currentChosen == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -269,17 +279,19 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
||||
}
|
||||
|
||||
c.disconnectEvent(rsn)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error {
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
|
||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
|
||||
// If no route is chosen, remove the route from the peer and system
|
||||
if newChosenID == "" {
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
if err := c.removeRouteFromPeerAndSystem(rsn); err != nil {
|
||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
@@ -319,6 +331,58 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) disconnectEvent(rsn reason) {
|
||||
var defaultRoute bool
|
||||
for _, r := range c.routes {
|
||||
if r.Network.Bits() == 0 {
|
||||
defaultRoute = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !defaultRoute {
|
||||
return
|
||||
}
|
||||
|
||||
var severity proto.SystemEvent_Severity
|
||||
var message string
|
||||
var userMessage string
|
||||
meta := make(map[string]string)
|
||||
|
||||
switch rsn {
|
||||
case reasonShutdown:
|
||||
severity = proto.SystemEvent_INFO
|
||||
message = "Default route removed"
|
||||
userMessage = "Exit node disconnected."
|
||||
meta["network"] = c.handler.String()
|
||||
case reasonRouteUpdate:
|
||||
severity = proto.SystemEvent_INFO
|
||||
message = "Default route updated due to configuration change"
|
||||
meta["network"] = c.handler.String()
|
||||
case reasonPeerUpdate:
|
||||
severity = proto.SystemEvent_WARNING
|
||||
message = "Default route disconnected due to peer unreachability"
|
||||
userMessage = "Exit node connection lost. Your internet access might be affected."
|
||||
if c.currentChosen != nil {
|
||||
meta["peer"] = c.currentChosen.Peer
|
||||
meta["network"] = c.handler.String()
|
||||
}
|
||||
default:
|
||||
severity = proto.SystemEvent_ERROR
|
||||
message = "Default route disconnected for unknown reason"
|
||||
userMessage = "Exit node disconnected for unknown reasons."
|
||||
meta["network"] = c.handler.String()
|
||||
}
|
||||
|
||||
c.statusRecorder.PublishEvent(
|
||||
severity,
|
||||
proto.SystemEvent_NETWORK,
|
||||
message,
|
||||
userMessage,
|
||||
meta,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
@@ -361,12 +425,12 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil {
|
||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
||||
}
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
@@ -385,7 +449,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
|
||||
if isTrueRouteUpdate {
|
||||
log.Debug("Client network update contains different routes, recalculating routes")
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
@@ -404,7 +468,7 @@ func handlerFromRoute(
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsRouterInteval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.IWGIface,
|
||||
wgInterface iface.WGIface,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
useNewDNSRoute bool,
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@@ -48,7 +48,7 @@ type Route struct {
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func NewRoute(
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.IWGIface,
|
||||
wgInterface iface.WGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
return &Route{
|
||||
|
||||
9
client/internal/routemanager/iface/iface.go
Normal file
9
client/internal/routemanager/iface/iface.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package iface
|
||||
|
||||
// WGIface defines subset methods of interface required for router
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
}
|
||||
22
client/internal/routemanager/iface/iface_common.go
Normal file
22
client/internal/routemanager/iface/iface_common.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type wgIfaceBase interface {
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
7
client/internal/routemanager/iface/iface_windows.go
Normal file
7
client/internal/routemanager/iface/iface_windows.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package iface
|
||||
|
||||
// WGIface defines subset methods of interface required for router
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
@@ -52,7 +52,7 @@ type ManagerConfig struct {
|
||||
Context context.Context
|
||||
PublicKey string
|
||||
DNSRouteInterval time.Duration
|
||||
WGInterface iface.IWGIface
|
||||
WGInterface iface.WGIface
|
||||
StatusRecorder *peer.Status
|
||||
RelayManager *relayClient.Manager
|
||||
InitialRoutes []*route.Route
|
||||
@@ -74,7 +74,7 @@ type DefaultManager struct {
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
relayMgr *relayClient.Manager
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier.Notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -22,6 +22,6 @@ func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||
func newServerRouter(context.Context, iface.WGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -22,11 +22,11 @@ type serverRouter struct {
|
||||
ctx context.Context
|
||||
routes map[route.ID]*route.Route
|
||||
firewall firewall.Manager
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
return &serverRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,7 +23,7 @@ const (
|
||||
)
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface iface.IWGIface) (map[string]int, error) {
|
||||
func Setup(wgIface iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
@@ -19,7 +19,7 @@ type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
// prefixes is tracking all the current added prefixes im memory
|
||||
// (this is used in iOS as all route updates require a full table update)
|
||||
//nolint
|
||||
@@ -30,7 +30,7 @@ type SysOps struct {
|
||||
notifier *notifier.Notifier
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
|
||||
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
@@ -149,7 +149,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
|
||||
Reference in New Issue
Block a user