Adding --external-ip-map and --dns-resolver-address and shorthand flags (#652)

Adding --external-ip-map and --dns-resolver-address to up command and shorthand option to global flags.

Refactor get and read config functions with new ConfigInput type.

updated cobra package to latest release.
This commit is contained in:
Maycon Santos
2023-01-17 19:16:50 +01:00
committed by GitHub
parent 12ae2e93fc
commit dcf6533ed5
21 changed files with 613 additions and 558 deletions

View File

@@ -16,6 +16,14 @@ import (
"google.golang.org/grpc/status"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
const ManagementLegacyPort = 33073
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-"}
var managementURLDefault *url.URL
func ManagementURLDefault() *url.URL {
@@ -32,10 +40,12 @@ func init() {
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
AdminURL string
ConfigPath string
PreSharedKey *string
ManagementURL string
AdminURL string
ConfigPath string
PreSharedKey *string
NATExternalIPs []string
CustomDNSAddress []byte
}
// Config Configuration type
@@ -68,6 +78,8 @@ type Config struct {
// "12.34.56.78/10.1.2.3" => interface IP 10.1.2.3 will be mapped to external IP of 12.34.56.78
NATExternalIPs []string
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
CustomDNSAddress string
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -84,6 +96,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
WgPort: iface.DefaultWgPort,
IFaceBlackList: []string{},
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
CustomDNSAddress: string(input.CustomDNSAddress),
}
if input.ManagementURL != "" {
URL, err := ParseURL("Management URL", input.ManagementURL)
@@ -107,8 +121,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config.AdminURL = newURL
}
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-"}
config.IFaceBlackList = defaultInterfaceBlacklist
err = util.WriteJson(input.ConfigPath, config)
if err != nil {
@@ -135,7 +148,7 @@ func ParseURL(serviceName, managementURL string) (*url.URL, error) {
return parsedMgmtURL, err
}
// ReadConfig reads existing config. In case provided managementURL is not empty overrides the read property
// ReadConfig reads existing configuration and update settings according to input configuration
func ReadConfig(input ConfigInput) (*Config, error) {
config := &Config{}
if _, err := os.Stat(input.ConfigPath); os.IsNotExist(err) {
@@ -176,6 +189,7 @@ func ReadConfig(input ConfigInput) (*Config, error) {
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
if config.SSHKey == "" {
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
@@ -189,6 +203,15 @@ func ReadConfig(input ConfigInput) (*Config, error) {
config.WgPort = iface.DefaultWgPort
refresh = true
}
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
config.NATExternalIPs = input.NATExternalIPs
refresh = true
}
if input.CustomDNSAddress != nil {
config.CustomDNSAddress = string(input.CustomDNSAddress)
refresh = true
}
if refresh {
// since we have new management URL, we need to update config file

View File

@@ -197,6 +197,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
WgPort: config.WgPort,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
}
if config.PreSharedKey != "" {
@@ -245,11 +246,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
return loginResp, nil
}
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
const ManagementLegacyPort = 33073
// UpdateOldManagementPort checks whether client can switch to the new Management port 443.
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version.

View File

@@ -45,6 +45,7 @@ type DefaultServer struct {
runtimePort int
runtimeIP string
previousConfigHash uint64
customAddress *netip.AddrPort
}
type registrationMap map[string]struct{}
@@ -55,7 +56,7 @@ type muxUpdate struct {
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*DefaultServer, error) {
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
mux := dns.NewServeMux()
dnsServer := &dns.Server{
@@ -66,6 +67,16 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*Default
ctx, stop := context.WithCancel(ctx)
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
if err != nil {
stop()
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
}
addrPort = &parsedAddrPort
}
defaultServer := &DefaultServer{
ctx: ctx,
stop: stop,
@@ -75,12 +86,14 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*Default
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
runtimePort: defaultPort,
wgInterface: wgInterface,
runtimePort: defaultPort,
customAddress: addrPort,
}
hostmanager, err := newHostManager(wgInterface)
if err != nil {
stop()
return nil, err
}
defaultServer.hostManager = hostmanager
@@ -90,13 +103,19 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*Default
// Start runs the listener in a go routine
func (s *DefaultServer) Start() {
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
} else {
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
}
s.runtimeIP = ip
s.runtimePort = port
}
s.runtimeIP = ip
s.runtimePort = port
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
log.Debugf("starting dns on %s", s.server.Addr)
@@ -105,7 +124,7 @@ func (s *DefaultServer) Start() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err = s.server.ListenAndServe()
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}

View File

@@ -8,8 +8,6 @@ import (
"github.com/netbirdio/netbird/iface"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
)
@@ -214,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface)
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Fatal(err)
}
@@ -265,74 +263,89 @@ func TestUpdateDNSServer(t *testing.T) {
}
func TestDNSServerStartStop(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager("127.0.0.1")
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
// todo review why this test is not working only on github actions workflows
t.Skip("skipping test in Windows CI workflows.")
}
dnsServer.hostManager = newNoopHostMocker()
dnsServer.Start()
time.Sleep(100 * time.Millisecond)
if !dnsServer.listenerIsRunning {
t.Fatal("dns server listener is not running")
}
defer dnsServer.Stop()
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
// retry test before exit, for slower systems
return d.DialContext(ctx, network, addr)
}
return conn, nil
testCases := []struct {
name string
addrPort string
}{
{
name: "Should Pass With Port Discovery",
},
{
name: "Should Pass With Custom Port",
addrPort: "127.0.0.1:3535",
},
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed to connect to the server, error: %v", err)
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
t.Log(ips)
dnsServer.hostManager = newNoopHostMocker()
dnsServer.Start()
time.Sleep(100 * time.Millisecond)
if !dnsServer.listenerIsRunning {
t.Fatal("dns server listener is not running")
}
defer dnsServer.Stop()
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
}
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
dnsServer.Stop()
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1)
defer cancel()
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
if err == nil {
t.Fatalf("we should encounter an error when querying a stopped server")
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
// retry test before exit, for slower systems
return d.DialContext(ctx, network, addr)
}
return conn, nil
},
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed to connect to the server, error: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
}
dnsServer.Stop()
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1)
defer cancel()
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
if err == nil {
t.Fatalf("we should encounter an error when querying a stopped server")
}
})
}
}
func getDefaultServerWithNoHostManager(ip string) *DefaultServer {
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
mux := dns.NewServeMux()
listenIP := defaultIP
if ip != "" {
listenIP = ip
var parsedAddrPort *netip.AddrPort
if addrPort != "" {
parsed, err := netip.ParseAddrPort(addrPort)
if err != nil {
t.Fatal(err)
}
parsedAddrPort = &parsed
}
dnsServer := &dns.Server{
Addr: fmt.Sprintf("%s:%d", ip, defaultPort),
Net: "udp",
Handler: mux,
UDPSize: 65535,
@@ -349,7 +362,6 @@ func getDefaultServerWithNoHostManager(ip string) *DefaultServer {
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
runtimePort: defaultPort,
runtimeIP: listenIP,
customAddress: parsedAddrPort,
}
}

View File

@@ -70,6 +70,8 @@ type EngineConfig struct {
SSHKey []byte
NATExternalIPs []string
CustomDNSAddress string
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -261,7 +263,8 @@ func (e *Engine) Start() error {
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
if e.dnsServer == nil {
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface)
// todo fix custom address
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
if err != nil {
return err
}
@@ -964,6 +967,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
var external, internal string
var externalIP, internalIP net.IP
var err error
split := strings.Split(mapping, "/")
if len(split) > 2 {
log.Warnf("ignoring invalid external mapping '%s', too many delimiters", mapping)
@@ -988,7 +992,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
external = split[0]
externalIP = net.ParseIP(external)
if externalIP == nil {
log.Warnf("invalid external IP, ignoring external IP mapping '%s'", mapping)
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
break
}
if externalIP != nil {

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"net/netip"
"os"
@@ -778,15 +779,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sport := 10010
sigServer, err := startSignal(sport)
sigServer, signalAddr, err := startSignal()
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mport := 33081
mgmtServer, err := startManagement(mport, dir)
mgmtServer, mgmtAddr, err := startManagement(dir)
if err != nil {
t.Fatal(err)
return
@@ -804,7 +803,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mport, sport)
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
@@ -870,16 +869,84 @@ loop:
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mport int, sport int) (*Engine, error) {
func Test_ParseNATExternalIPMappings(t *testing.T) {
ifaceList, err := net.Interfaces()
if err != nil {
t.Fatalf("could get the interface list, got error: %s", err)
}
var testingIP string
var testingInterface string
for _, iface := range ifaceList {
addrList, err := iface.Addrs()
if err != nil {
t.Fatalf("could get the addr list, got error: %s", err)
}
for _, addr := range addrList {
prefix := netip.MustParsePrefix(addr.String())
if prefix.Addr().Is4() && !prefix.Addr().IsLoopback() {
testingIP = prefix.Addr().String()
testingInterface = iface.Name
}
}
}
testCases := []struct {
name string
inputMapList []string
inputBlacklistInterface []string
expectedOutput []string
}{
{
name: "Parse Valid List Should Be OK",
inputBlacklistInterface: defaultInterfaceBlacklist,
inputMapList: []string{"1.1.1.1", "8.8.8.8/" + testingInterface},
expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP},
},
{
name: "Only Interface Name Should Return Nil",
inputBlacklistInterface: defaultInterfaceBlacklist,
inputMapList: []string{testingInterface},
expectedOutput: nil,
},
{
name: "Invalid IP Return Nil",
inputBlacklistInterface: defaultInterfaceBlacklist,
inputMapList: []string{"1.1.1.1000"},
expectedOutput: nil,
},
{
name: "Invalid Mapping Element Should return Nil",
inputBlacklistInterface: defaultInterfaceBlacklist,
inputMapList: []string{"1.1.1.1/10.10.10.1/eth0"},
expectedOutput: nil,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{
IFaceBlackList: testCase.inputBlacklistInterface,
NATExternalIPs: testCase.inputMapList,
},
}
parsedList := engine.parseNATExternalIPMappings()
require.ElementsMatchf(t, testCase.expectedOutput, parsedList, "elements of parsed list should match expected list")
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, fmt.Sprintf("localhost:%d", mport), key, false)
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, fmt.Sprintf("localhost:%d", sport), key, false)
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
@@ -913,10 +980,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, nbstatus.NewRecorder()), nil
}
func startSignal(port int) (*grpc.Server, error) {
func startSignal() (*grpc.Server, string, error) {
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
@@ -929,10 +996,10 @@ func startSignal(port int) (*grpc.Server, error) {
}
}()
return s, nil
return s, lis.Addr().String(), nil
}
func startManagement(port int, dataDir string) (*grpc.Server, error) {
func startManagement(dataDir string) (*grpc.Server, string, error) {
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
@@ -944,9 +1011,9 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
HttpConfig: nil,
}
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, err
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewFileStore(config.Datadir)
@@ -956,17 +1023,17 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
peersUpdateManager := server.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
return nil, "", nil
}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore)
if err != nil {
return nil, err
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil {
return nil, err
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
@@ -975,5 +1042,5 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
}
}()
return s, nil
return s, lis.Addr().String(), nil
}