mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Mobile (#735)
Initial modification to support mobile client Export necessary interfaces for Android framework
This commit is contained in:
@@ -73,6 +73,25 @@ type Config struct {
|
||||
CustomDNSAddress string
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
if configFileIsExists(configPath) {
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = WriteOutConfig(configPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
if !configFileIsExists(input.ConfigPath) {
|
||||
@@ -86,7 +105,12 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !configFileIsExists(input.ConfigPath) {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
return createNewConfig(input)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = WriteOutConfig(input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
@@ -95,6 +119,16 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
return update(input)
|
||||
}
|
||||
|
||||
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
return createNewConfig(input)
|
||||
}
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(path, config)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
wgKey := generateKey()
|
||||
@@ -146,12 +180,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
config.IFaceBlackList = defaultInterfaceBlacklist
|
||||
|
||||
err = util.WriteJson(input.ConfigPath, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter) error {
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -60,6 +60,8 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
|
||||
|
||||
statusRecorder.MarkManagementDisconnected()
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
defer statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
select {
|
||||
@@ -144,7 +146,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -191,11 +193,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter) (*EngineConfig, error) {
|
||||
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
|
||||
@@ -57,6 +57,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map
|
||||
|
||||
@@ -1,27 +1,6 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
import nbdns "github.com/netbirdio/netbird/dns"
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
@@ -29,444 +8,3 @@ type Server interface {
|
||||
Stop()
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
}
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
upstreamCtxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registrationMap
|
||||
localResolver *localResolver
|
||||
wgInterface *iface.WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
customAddress *netip.AddrPort
|
||||
}
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler dns.Handler
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
if err != nil {
|
||||
stop()
|
||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||
}
|
||||
addrPort = &parsedAddrPort
|
||||
}
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
runtimePort: defaultPort,
|
||||
customAddress: addrPort,
|
||||
}
|
||||
|
||||
hostmanager, err := newHostManager(wgInterface)
|
||||
if err != nil {
|
||||
stop()
|
||||
return nil, err
|
||||
}
|
||||
defaultServer.hostManager = hostmanager
|
||||
return defaultServer, err
|
||||
}
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *DefaultServer) Start() {
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
} else {
|
||||
ip, port, err := s.getFirstListenerAvailable()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
s.runtimeIP = ip
|
||||
s.runtimePort = port
|
||||
}
|
||||
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.ctxCancel()
|
||||
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
err = s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
}
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
localRecords[key] = record
|
||||
}
|
||||
}
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
// clean up the previous upstream resolver
|
||||
if s.upstreamCtxCancel != nil {
|
||||
s.upstreamCtxCancel()
|
||||
}
|
||||
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
log.Warn("received a nameserver group with empty nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
var ctx context.Context
|
||||
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||
|
||||
handler := newUpstreamResolver(ctx)
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
// when upstream fails to resolve domain several times over all it servers
|
||||
// it will calls this hook to exclude self from the configuration and
|
||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||
// because it is temporal deactivation until next try
|
||||
//
|
||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
}
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registrationMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = struct{}{}
|
||||
}
|
||||
|
||||
for key := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
s.deregisterMux(key)
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
) (deactivate func(), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("temporary deactivate nameservers group due timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
removeIndex[domain] = -1
|
||||
}
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.routeAll = false
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.deregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
}
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||
}
|
||||
}
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for domain, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.registerMux(domain, handler)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Debug("reactivate temporary disabled nameserver group")
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.routeAll = true
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
32
client/internal/dns/server_android.go
Normal file
32
client/internal/dns/server_android.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// DefaultServer dummy dns server
|
||||
type DefaultServer struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer On Android the DNS feature is not supported yet
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||
return &DefaultServer{}, nil
|
||||
}
|
||||
|
||||
// Start dummy implementation
|
||||
func (s DefaultServer) Start() {
|
||||
|
||||
}
|
||||
|
||||
// Stop dummy implementation
|
||||
func (s DefaultServer) Stop() {
|
||||
|
||||
}
|
||||
|
||||
// UpdateDNSServer dummy implementation
|
||||
func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
465
client/internal/dns/server_nonandroid.go
Normal file
465
client/internal/dns/server_nonandroid.go
Normal file
@@ -0,0 +1,465 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
upstreamCtxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registrationMap
|
||||
localResolver *localResolver
|
||||
wgInterface *iface.WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
customAddress *netip.AddrPort
|
||||
}
|
||||
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler dns.Handler
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
if err != nil {
|
||||
stop()
|
||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||
}
|
||||
addrPort = &parsedAddrPort
|
||||
}
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
runtimePort: defaultPort,
|
||||
customAddress: addrPort,
|
||||
}
|
||||
|
||||
hostmanager, err := newHostManager(wgInterface)
|
||||
if err != nil {
|
||||
stop()
|
||||
return nil, err
|
||||
}
|
||||
defaultServer.hostManager = hostmanager
|
||||
return defaultServer, err
|
||||
}
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *DefaultServer) Start() {
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
} else {
|
||||
ip, port, err := s.getFirstListenerAvailable()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
s.runtimeIP = ip
|
||||
s.runtimePort = port
|
||||
}
|
||||
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.ctxCancel()
|
||||
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
err = s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
}
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
localRecords[key] = record
|
||||
}
|
||||
}
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
// clean up the previous upstream resolver
|
||||
if s.upstreamCtxCancel != nil {
|
||||
s.upstreamCtxCancel()
|
||||
}
|
||||
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
log.Warn("received a nameserver group with empty nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
var ctx context.Context
|
||||
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||
|
||||
handler := newUpstreamResolver(ctx)
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
// when upstream fails to resolve domain several times over all it servers
|
||||
// it will calls this hook to exclude self from the configuration and
|
||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||
// because it is temporal deactivation until next try
|
||||
//
|
||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
}
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registrationMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = struct{}{}
|
||||
}
|
||||
|
||||
for key := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
s.deregisterMux(key)
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
) (deactivate func(), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("temporary deactivate nameservers group due timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
removeIndex[domain] = -1
|
||||
}
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.routeAll = false
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.deregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
}
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||
}
|
||||
}
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for domain, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.registerMux(domain, handler)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Debug("reactivate temporary disabled nameserver group")
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.routeAll = true
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -199,7 +199,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU)
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -46,6 +46,8 @@ var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
// TunAdapter is option. It is necessary for mobile version.
|
||||
TunAdapter iface.TunAdapter
|
||||
|
||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||
WgAddr string
|
||||
@@ -173,7 +175,7 @@ func (e *Engine) Start() error {
|
||||
myPrivateKey := e.config.WgPrivateKey
|
||||
var err error
|
||||
|
||||
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
|
||||
return err
|
||||
@@ -614,6 +616,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
|
||||
@@ -207,7 +207,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -549,7 +549,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@@ -714,7 +714,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
|
||||
@@ -2,37 +2,26 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
|
||||
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", config.ManagementURL.String(), err)
|
||||
return err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
@@ -42,40 +31,84 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(sshKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = loginPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed logging-in peer on Management Service : %v", err)
|
||||
|
||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||
return err
|
||||
}
|
||||
log.Infof("peer has successfully logged-in to the Management service %s", config.ManagementURL.String())
|
||||
return nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
|
||||
func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
loginResp, err := client.Login(serverPublicKey, sysInfo, pubSSHKey)
|
||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
log.Debugf("peer registration required")
|
||||
return registerPeer(ctx, serverPublicKey, client, setupKey, jwtToken, pubSSHKey)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
var mgmTlsEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||
return serverKey, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
@@ -98,3 +131,31 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
|
||||
@@ -161,7 +162,10 @@ func (conn *Conn) reCreateAgent() error {
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
failedTimeout := 6 * time.Second
|
||||
var err error
|
||||
transportNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
@@ -172,6 +176,7 @@ func (conn *Conn) reCreateAgent() error {
|
||||
UDPMux: conn.config.UDPMux,
|
||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||
Net: transportNet,
|
||||
}
|
||||
|
||||
if conn.config.DisableIPv6Discovery {
|
||||
|
||||
9
client/internal/peer/listener.go
Normal file
9
client/internal/peer/listener.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package peer
|
||||
|
||||
// Listener is a callback type about the NetBird network connection state
|
||||
type Listener interface {
|
||||
OnConnected()
|
||||
OnDisconnected()
|
||||
OnConnecting()
|
||||
OnPeersListChanged(int)
|
||||
}
|
||||
124
client/internal/peer/notifier.go
Normal file
124
client/internal/peer/notifier.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
stateDisconnected = iota
|
||||
stateConnected
|
||||
stateConnecting
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
serverStateLock sync.Mutex
|
||||
listenersLock sync.Mutex
|
||||
listeners map[Listener]struct{}
|
||||
currentServerState bool
|
||||
currentClientState bool
|
||||
lastNotification int
|
||||
}
|
||||
|
||||
func newNotifier() *notifier {
|
||||
return ¬ifier{
|
||||
listeners: make(map[Listener]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *notifier) addListener(listener Listener) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
n.serverStateLock.Lock()
|
||||
go n.notifyListener(listener, n.lastNotification)
|
||||
n.serverStateLock.Unlock()
|
||||
n.listeners[listener] = struct{}{}
|
||||
}
|
||||
|
||||
func (n *notifier) removeListener(listener Listener) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
delete(n.listeners, listener)
|
||||
}
|
||||
|
||||
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
|
||||
var newState bool
|
||||
if mgmState && signalState {
|
||||
newState = true
|
||||
} else {
|
||||
newState = false
|
||||
}
|
||||
|
||||
if !n.isServerStateChanged(newState) {
|
||||
return
|
||||
}
|
||||
|
||||
n.currentServerState = newState
|
||||
n.lastNotification = n.calculateState(newState, n.currentClientState)
|
||||
|
||||
go n.notifyAll(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStart() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = true
|
||||
n.lastNotification = n.calculateState(n.currentServerState, true)
|
||||
go n.notifyAll(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStop() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = false
|
||||
n.lastNotification = n.calculateState(n.currentServerState, false)
|
||||
go n.notifyAll(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) isServerStateChanged(newState bool) bool {
|
||||
return n.currentServerState != newState
|
||||
}
|
||||
|
||||
func (n *notifier) notifyAll(state int) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
for l := range n.listeners {
|
||||
n.notifyListener(l, state)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *notifier) notifyListener(l Listener, state int) {
|
||||
switch state {
|
||||
case stateDisconnected:
|
||||
l.OnDisconnected()
|
||||
case stateConnected:
|
||||
l.OnConnected()
|
||||
case stateConnecting:
|
||||
l.OnConnecting()
|
||||
}
|
||||
}
|
||||
|
||||
func (n *notifier) calculateState(serverState bool, clientState bool) int {
|
||||
if serverState && clientState {
|
||||
return stateConnected
|
||||
}
|
||||
|
||||
if !clientState {
|
||||
return stateDisconnected
|
||||
}
|
||||
|
||||
return stateConnecting
|
||||
}
|
||||
|
||||
func (n *notifier) peerListChanged(numOfPeers int) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
for l := range n.listeners {
|
||||
l.OnPeersListChanged(numOfPeers)
|
||||
}
|
||||
}
|
||||
32
client/internal/peer/notifier_test.go
Normal file
32
client/internal/peer/notifier_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_notifier_serverState(t *testing.T) {
|
||||
|
||||
type scenario struct {
|
||||
name string
|
||||
expected bool
|
||||
mgmState bool
|
||||
signalState bool
|
||||
}
|
||||
scenarios := []scenario{
|
||||
{"connected", true, true, true},
|
||||
{"mgm down", false, false, true},
|
||||
{"signal down", false, true, false},
|
||||
{"disconnected", false, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range scenarios {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := newNotifier()
|
||||
n.updateServerStates(tt.mgmState, tt.signalState)
|
||||
if n.currentServerState != tt.expected {
|
||||
t.Errorf("invalid serverstate: %t, expected: %t", n.currentServerState, tt.expected)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -58,6 +58,7 @@ type Status struct {
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
@@ -66,6 +67,7 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
}
|
||||
}
|
||||
@@ -114,6 +116,7 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return errors.New("no peer with to remove")
|
||||
}
|
||||
|
||||
@@ -148,6 +151,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -164,6 +168,7 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
peerState.FQDN = fqdn
|
||||
d.peers[peerPubKey] = peerState
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -199,6 +204,8 @@ func (d *Status) CleanLocalPeerState() {
|
||||
func (d *Status) MarkManagementDisconnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = false
|
||||
}
|
||||
|
||||
@@ -206,7 +213,9 @@ func (d *Status) MarkManagementDisconnected() {
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.managementState = true
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = true
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
@@ -227,13 +236,17 @@ func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
||||
func (d *Status) MarkSignalDisconnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.signalState = false
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = false
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = true
|
||||
}
|
||||
|
||||
@@ -262,3 +275,31 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
|
||||
return fullStatus
|
||||
}
|
||||
|
||||
// ClientStart will notify all listeners about the new service state
|
||||
func (d *Status) ClientStart() {
|
||||
d.notifier.clientStart()
|
||||
}
|
||||
|
||||
// ClientStop will notify all listeners about the new service state
|
||||
func (d *Status) ClientStop() {
|
||||
d.notifier.clientStop()
|
||||
}
|
||||
|
||||
// AddConnectionListener add a listener to the notifier
|
||||
func (d *Status) AddConnectionListener(listener Listener) {
|
||||
d.notifier.addListener(listener)
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove a listener from the notifier
|
||||
func (d *Status) RemoveConnectionListener(listener Listener) {
|
||||
d.notifier.removeListener(listener)
|
||||
}
|
||||
|
||||
func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(len(d.peers))
|
||||
}
|
||||
|
||||
@@ -1,190 +1,9 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/route"
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
Stop()
|
||||
}
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRoutes map[string]*route.Route
|
||||
serverRouter *serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
}
|
||||
|
||||
// NewManager returns a new route manager
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
return &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
serverRoutes: make(map[string]*route.Route),
|
||||
serverRouter: &serverRouter{
|
||||
routes: make(map[string]*route.Route),
|
||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
||||
firewall: NewFirewall(ctx),
|
||||
},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.stop()
|
||||
m.serverRouter.firewall.CleanRoutingRules()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
}
|
||||
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for routeID := range m.serverRoutes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.serverRoutes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.serverRoutes, routeID)
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.serverRoutes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.serverRoutes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.serverRoutes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
ownNetworkIDs := make(map[string]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[networkID] = true
|
||||
// only linux is supported for now
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
continue
|
||||
}
|
||||
newServerRoutesMap[newRoute.ID] = newRoute
|
||||
}
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if !ownNetworkIDs[networkID] {
|
||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < 7 {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||
version.NetbirdVersion(), newRoute.Network)
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||
}
|
||||
}
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
|
||||
err := m.updateServerRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
31
client/internal/routemanager/manager_android.go
Normal file
31
client/internal/routemanager/manager_android.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// DefaultManager dummy router manager for Android
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
serverRouter *serverRouter
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
// NewManager returns a new dummy route manager what doing nothing
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||
return &DefaultManager{}
|
||||
}
|
||||
|
||||
// UpdateRoutes ...
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop ...
|
||||
func (m *DefaultManager) Stop() {
|
||||
|
||||
}
|
||||
186
client/internal/routemanager/manager_nonandroid.go
Normal file
186
client/internal/routemanager/manager_nonandroid.go
Normal file
@@ -0,0 +1,186 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRoutes map[string]*route.Route
|
||||
serverRouter *serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
}
|
||||
|
||||
// NewManager returns a new route manager
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
return &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
serverRoutes: make(map[string]*route.Route),
|
||||
serverRouter: &serverRouter{
|
||||
routes: make(map[string]*route.Route),
|
||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
||||
firewall: NewFirewall(ctx),
|
||||
},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.stop()
|
||||
m.serverRouter.firewall.CleanRoutingRules()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
}
|
||||
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for routeID := range m.serverRoutes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.serverRoutes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.serverRoutes, routeID)
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.serverRoutes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.serverRoutes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.serverRoutes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
ownNetworkIDs := make(map[string]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[networkID] = true
|
||||
// only linux is supported for now
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
continue
|
||||
}
|
||||
newServerRoutesMap[newRoute.ID] = newRoute
|
||||
}
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if !ownNetworkIDs[networkID] {
|
||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < 7 {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||
version.NetbirdVersion(), newRoute.Network)
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||
}
|
||||
}
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
|
||||
err := m.updateServerRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -391,7 +391,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user