Compare commits

...

49 Commits

Author SHA1 Message Date
Maycon Santos
a814715ef8 Add resolvconf configurator for linux (#592) 2022-11-29 14:51:18 +01:00
Maycon Santos
4a30b66503 Check if system is our manager when resolvconf (#590)
Sometimes resolvconf will manage the /etc/resolv.conf file
And systemd-resolved still the DNS manager
2022-11-29 13:37:50 +01:00
Maycon Santos
ae500b63a7 User custom loopback address (#589)
We will probe a set of addresses and port
to define the one available for our DNS service

if none is available, we return an error
2022-11-29 11:49:18 +01:00
Maycon Santos
20a73e3e14 Sync peers FQDN (#584)
Use stdout and stderr log path only if on Linux and attempt to create the path

Update status system with FQDN fields and 
status command to display the domain names of remote and local peers

Set some DNS logs to tracing

update readme file
2022-11-26 13:29:50 +01:00
Misha Bragin
fcf7786a85 Disable route when removing peer (#582) 2022-11-25 18:11:07 +01:00
Maycon Santos
a78fd69f80 Feature/dns client configuration (#563)
Added host configurators for Linux, Windows, and macOS.

The host configurator will update the peer system configuration
 directing DNS queries according to its capabilities.

Some Linux distributions don't support split (match) DNS or custom ports,
 and that will be reported to our management system in another PR
2022-11-23 13:39:42 +01:00
Genteure
4bd5029e7b Enable IPv6 address discovery (#578)
Agents will use IPv6 when available for ICE negotiation
2022-11-23 11:03:29 +01:00
Tom Kunicki
f604956246 External NAT IP mapping support (#487)
* External NAT IP mapping support

* Ignore blacklisted interfaces, even if in user specified in  mapping
2022-11-23 08:42:12 +01:00
Misha Bragin
53c532bbb4 Fix interactive SSO login when creating account from a device (#575) 2022-11-22 12:37:36 +01:00
Misha Bragin
8b0a1bbae0 Display peers of a user that it has access to (#571)
If a user has a non-admin role, display all peers
that user's peers have access to when calling
/peers endpoint of the HTTP API.
2022-11-21 17:45:14 +01:00
Misha Bragin
e965d6c022 Fix CISPA note 2022-11-21 17:36:07 +01:00
Misha Bragin
11f8249eed Add CISPA note (#572) 2022-11-21 16:38:41 +01:00
Maycon Santos
d63a9ce4a7 Return peer's FQDN via API (#567)
Added a temp method to retrieve the dns domain
2022-11-21 11:14:42 +01:00
Maycon Santos
9cb66bdb5d Update last run time and active count (#568)
* Update last run time and active count

We will collect the active peer min and max versions

* Get UI client usage
2022-11-18 16:35:13 +01:00
Genteure
c8ace8bbbe Fix docker network interface filter (#564)
docker network address are assigned on network interfaces that start with "br-"
2022-11-15 22:07:58 +01:00
Misha Bragin
509d23c7cf Replace gRPC errors in business logic with internal ones (#558) 2022-11-11 20:36:45 +01:00
Misha Bragin
1db4027bea Remove docs typo 2022-11-10 10:48:00 +01:00
Misha Bragin
d4dbc322be Add ref to ICE in Readme 2022-11-10 10:46:40 +01:00
Misha Bragin
e19d5dca7f Refactor AddPeer to ensure consistency (#557) 2022-11-08 16:14:36 +01:00
Maycon Santos
157137e4ad Use a single way to generate network map (#550) 2022-11-08 11:38:40 +01:00
Maycon Santos
7d7e576775 Set report caller when info or higher (#555) 2022-11-08 10:56:13 +01:00
Misha Bragin
f37b43a542 Save Peer Status separately in the FileStore (#554)
Due to peer reconnects when restarting the Management service,
there are lots of SaveStore operations to update peer status.

Store.SavePeerStatus stores peer status separately and the
FileStore implementation stores it in memory.
2022-11-08 10:46:12 +01:00
Maycon Santos
7e262572a4 Move dns label generation to store (#552) 2022-11-08 10:31:34 +01:00
Misha Bragin
a768a0aa8a Always lock the store when getting an account (#551) 2022-11-07 19:09:22 +01:00
Misha Bragin
ed7ac81027 Introduce locking on the account level (#548) 2022-11-07 17:52:23 +01:00
Maycon Santos
1f845f466c Add account copy test (#549) 2022-11-07 17:37:28 +01:00
Maycon Santos
270f0e4ce8 Feature/dns protocol (#543)
Added DNS update protocol message

Added sync to clients

Update nameserver API with new fields

Added default NS groups

Added new dns-name flag for the management service append to peer DNS label
2022-11-07 15:38:21 +01:00
Misha Bragin
d0c6d88971 Simplified Store Interface (#545)
This PR simplifies Store and FileStore
by keeping just the Get and Save account methods.

The AccountManager operates mostly around
a single account, so it makes sense to fetch
the whole account object from the store.
2022-11-07 12:10:56 +01:00
Misha Bragin
4321b71984 Hide content based on user role (#541) 2022-11-05 10:24:50 +01:00
Maycon Santos
e8d82c1bd3 Feature/dns-server (#537)
Adding DNS server for client

Updated the API with new fields

Added custom zone object for peer's DNS resolution
2022-11-03 18:39:37 +01:00
Misha Bragin
6aa7a2c5e1 Hide setup key from non-admin users (#539) 2022-11-03 17:02:31 +01:00
Rui Lopes
2e0bf61e9a correctly set the windows application icon on windows (#535)
the icon format is not really supported, so this uses a png instead.

this closes https://github.com/netbirdio/netbird/issues/534.
2022-11-01 00:34:30 +01:00
Maycon Santos
126af9dffc Return gateway address if not nil (#533)
If the gateway address would be nil which is
the case on macOS, we return the preferredSrc

added tests for getExistingRIBRouteGateway function

update log message
2022-10-31 11:54:34 +01:00
Maycon Santos
4cdf2df660 Update sign pipeline version to 0.0.4 (#531)
This version has a fix for the
macOS UI client architecture
2022-10-31 11:03:42 +01:00
Maycon Santos
9a4c9aa286 Add active peers count per OS (#526)
* Add active peers count per OS

* increase iface tests timeout
2022-10-26 14:48:40 +02:00
Rui Lopes
5ed61700ff Set the application icon, settings window title and systray tooltip (#523) 2022-10-26 14:34:30 +02:00
Misha Bragin
84117a9fb7 Update WireGuard trademark note 2022-10-23 11:47:42 +02:00
Misha Bragin
92b612eba4 Update demo video link 2022-10-22 16:55:49 +02:00
Misha Bragin
aeeaa21eed Update README.md (#524) 2022-10-22 16:19:16 +02:00
Misha Bragin
d228cd0cb1 Remove release note 2022-10-22 15:10:09 +02:00
Misha Bragin
b41f36fccd Add gRPC metrics (#522) 2022-10-22 15:06:54 +02:00
Misha Bragin
d2cde4a040 Add IdP metrics (#521) 2022-10-22 13:29:39 +02:00
Misha Bragin
84879a356b Extract app metrics to a separate struct (#520) 2022-10-22 11:50:21 +02:00
Misha Bragin
ed2214f9a9 Add HTTP request/response totals to metrics (#519) 2022-10-22 10:07:13 +02:00
braginini
4388dcc20b Listen metrics on all interfaces 2022-10-21 16:50:06 +02:00
Misha Bragin
4f1f0df7d2 Add Open-telemetry support (#517)
This PR brings open-telemetry metrics to the
Management service.
The Management service exposes new HTTP endpoint
/metrics on 8081 port by default.
The port can be changed by specifying
--metrics-port PORT flag when starting the service.
2022-10-21 16:24:13 +02:00
Misha Bragin
08ddf04c5f Fix IdP tests (#516) 2022-10-19 18:36:10 +02:00
Misha Bragin
b5ee2174a8 Do not set wt_pending_invite when unnecessary (#515)
wt_pending_invite property is set for every user on IdP.
Avoid setting it when unnecessary.
2022-10-19 17:51:41 +02:00
Misha Bragin
7218a3d563 Management single account mode (#511) 2022-10-19 17:43:28 +02:00
116 changed files with 7600 additions and 2359 deletions

View File

@@ -25,7 +25,6 @@ jobs:
needs: pre needs: pre
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.3" SIGN_PIPE_VER: "v0.0.4"
GORELEASER_VER: "v1.6.3" GORELEASER_VER: "v1.6.3"
jobs: jobs:

3
.gitignore vendored
View File

@@ -10,4 +10,5 @@ infrastructure_files/management.json
infrastructure_files/docker-compose.yml infrastructure_files/docker-compose.yml
*.syso *.syso
client/.distfiles/ client/.distfiles/
infrastructure_files/setup.env infrastructure_files/setup.env
.vscode

View File

@@ -1,6 +1,6 @@
<p align="center"> <p align="center">
<strong>:hatching_chick: New release! NetBird Easy SSH</strong>. <strong>:hatching_chick: New Release! DNS support.</strong>
<a href="https://github.com/netbirdio/netbird/releases/tag/v0.8.0"> <a href="https://github.com/netbirdio/netbird/releases">
Learn more Learn more
</a> </a>
</p> </p>
@@ -40,7 +40,7 @@
It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
NetBird creates an overlay peer-to-peer network connecting machines automatically regardless of their location (home, office, datacenter, container, cloud or edge environments) unifying virtual private network management experience. NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactive_Connectivity_Establishment) to automatically create an overlay peer-to-peer network connecting machines regardless of location (home, office, data center, container, cloud, or edge environments), unifying virtual private network management experience.
**Key features:** **Key features:**
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard)) - \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
@@ -55,16 +55,15 @@ NetBird creates an overlay peer-to-peer network connecting machines automaticall
- \[x] Access Controls - groups & rules. - \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys. - \[x] Remote SSH access without managing SSH keys.
- \[x] Network Routes. - \[x] Network Routes.
- \[x] Private DNS.
**Coming soon:** **Coming soon:**
- \[ ] Private DNS.
- \[ ] Mobile clients. - \[ ] Mobile clients.
- \[ ] Network Activity Monitoring. - \[ ] Network Activity Monitoring.
### Secure peer-to-peer VPN with SSO and MFA in minutes ### Secure peer-to-peer VPN with SSO and MFA in minutes
<p float="left" align="middle">
<img src="docs/media/netbird-sso-mfa-demo.gif" width="800"/> https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
</p>
**Note**: The `main` branch may be in an *unstable or even broken state* during development. **Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases). For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
@@ -100,9 +99,15 @@ See a complete [architecture overview](https://netbird.io/docs/overview/architec
### Community projects ### Community projects
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird) - [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
### Support acknowledgement
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Testimonials ### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution). We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal ### Legal
[WireGuard](https://wireguard.com/) is a registered trademark of Jason A. Donenfeld. _WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -32,6 +32,7 @@ func newSVCConfig() *service.Config {
Name: name, Name: name,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.", Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
Option: make(service.KeyValue),
} }
} }

View File

@@ -2,6 +2,8 @@ package cmd
import ( import (
"context" "context"
"os"
"path/filepath"
"runtime" "runtime"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -32,13 +34,34 @@ var installCmd = &cobra.Command{
} }
if managementURL != "" { if managementURL != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url") svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
svcConfig.Arguments = append(svcConfig.Arguments, managementURL) }
if logFile != "console" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
} }
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
// Respected only by systemd systems // Respected only by systemd systems
svcConfig.Dependencies = []string{"After=network.target syslog.target"} svcConfig.Dependencies = []string{"After=network.target syslog.target"}
if logFile != "console" {
setStdLogPath := true
dir := filepath.Dir(logFile)
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0750)
if err != nil {
setStdLogPath = false
}
}
if setStdLogPath {
svcConfig.Option["LogOutput"] = true
svcConfig.Option["LogDirectory"] = dir
}
}
} }
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())

View File

@@ -122,6 +122,7 @@ func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
fullStatus.LocalPeerState.IP = localPeerState.GetIP() fullStatus.LocalPeerState.IP = localPeerState.GetIP()
fullStatus.LocalPeerState.PubKey = localPeerState.GetPubKey() fullStatus.LocalPeerState.PubKey = localPeerState.GetPubKey()
fullStatus.LocalPeerState.KernelInterface = localPeerState.GetKernelInterface() fullStatus.LocalPeerState.KernelInterface = localPeerState.GetKernelInterface()
fullStatus.LocalPeerState.FQDN = localPeerState.GetFqdn()
var peersState []nbStatus.PeerState var peersState []nbStatus.PeerState
@@ -136,6 +137,7 @@ func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
Direct: pbPeerState.GetDirect(), Direct: pbPeerState.GetDirect(),
LocalIceCandidateType: pbPeerState.GetLocalIceCandidateType(), LocalIceCandidateType: pbPeerState.GetLocalIceCandidateType(),
RemoteIceCandidateType: pbPeerState.GetRemoteIceCandidateType(), RemoteIceCandidateType: pbPeerState.GetRemoteIceCandidateType(),
FQDN: pbPeerState.GetFqdn(),
} }
peersState = append(peersState, peerState) peersState = append(peersState, peerState)
} }
@@ -196,6 +198,7 @@ func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonSta
"%s"+ // daemon status "%s"+ // daemon status
"Management: %s%s\n"+ "Management: %s%s\n"+
"Signal: %s%s\n"+ "Signal: %s%s\n"+
"Domain: %s\n"+
"NetBird IP: %s\n"+ "NetBird IP: %s\n"+
"Interface type: %s\n"+ "Interface type: %s\n"+
"Peers count: %s\n", "Peers count: %s\n",
@@ -206,6 +209,7 @@ func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonSta
managementStatusURL, managementStatusURL,
signalConnString, signalConnString,
signalStatusURL, signalStatusURL,
fullStatus.LocalPeerState.FQDN,
interfaceIP, interfaceIP,
interfaceTypeString, interfaceTypeString,
peersCountString, peersCountString,
@@ -266,7 +270,7 @@ func parsePeers(peers []nbStatus.PeerState, printDetail bool) (string, int) {
} }
peerString := fmt.Sprintf( peerString := fmt.Sprintf(
"\n Peer:\n"+ "\n %s:\n"+
" NetBird IP: %s\n"+ " NetBird IP: %s\n"+
" Public key: %s\n"+ " Public key: %s\n"+
" Status: %s\n"+ " Status: %s\n"+
@@ -275,6 +279,7 @@ func parsePeers(peers []nbStatus.PeerState, printDetail bool) (string, int) {
" Direct: %t\n"+ " Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+ " ICE candidate (Local/Remote): %s/%s\n"+
" Last connection update: %s\n", " Last connection update: %s\n",
peerState.FQDN,
peerState.IP, peerState.IP,
peerState.PubKey, peerState.PubKey,
peerState.ConnStatus, peerState.ConnStatus,

View File

@@ -62,18 +62,18 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, err := mgmt.NewStore(config.Datadir) store, err := mgmt.NewFileStore(config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peersUpdateManager := mgmt.NewPeersUpdateManager() peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil) accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -3,6 +3,9 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
"os"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
@@ -11,8 +14,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"net/url"
"os"
) )
var managementURLDefault *url.URL var managementURLDefault *url.URL
@@ -32,15 +33,33 @@ func init() {
// Config Configuration type // Config Configuration type
type Config struct { type Config struct {
// Wireguard private key of local peer // Wireguard private key of local peer
PrivateKey string PrivateKey string
PreSharedKey string PreSharedKey string
ManagementURL *url.URL ManagementURL *url.URL
AdminURL *url.URL AdminURL *url.URL
WgIface string WgIface string
WgPort int WgPort int
IFaceBlackList []string IFaceBlackList []string
DisableIPv6Discovery bool
// SSHKey is a private SSH key in a PEM format // SSHKey is a private SSH key in a PEM format
SSHKey string SSHKey string
// ExternalIP mappings, if different than the host interface IP
//
// External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP
// to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT
// mapping ExternalIP to host interface IP, or a NAT DMZ to host interface IP.
//
// A single mapping will take the form of: external[/internal]
// external (required): either the external IP address or "stun" to use STUN to determine the external IP address
// internal (optional): either the internal/interface IP address or an interface name
//
// examples:
// "12.34.56.78" => all interfaces IPs will be mapped to external IP of 12.34.56.78
// "12.34.56.78/eth0" => IPv4 assigned to interface eth0 will be mapped to external IP of 12.34.56.78
// "12.34.56.78/10.1.2.3" => interface IP 10.1.2.3 will be mapped to external IP of 12.34.56.78
NATExternalIPs []string
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -51,11 +70,12 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
return nil, err return nil, err
} }
config := &Config{ config := &Config{
SSHKey: string(pem), SSHKey: string(pem),
PrivateKey: wgKey, PrivateKey: wgKey,
WgIface: iface.WgInterfaceDefault, WgIface: iface.WgInterfaceDefault,
WgPort: iface.DefaultWgPort, WgPort: iface.DefaultWgPort,
IFaceBlackList: []string{}, IFaceBlackList: []string{},
DisableIPv6Discovery: false,
} }
if managementURL != "" { if managementURL != "" {
URL, err := ParseURL("Management URL", managementURL) URL, err := ParseURL("Management URL", managementURL)
@@ -80,7 +100,7 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
} }
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "utun", "wg", "ts", config.IFaceBlackList = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale", "docker", "vet"} "Tailscale", "tailscale", "docker", "veth", "br-"}
err = util.WriteJson(configPath, config) err = util.WriteJson(configPath, config)
if err != nil { if err != nil {

View File

@@ -3,11 +3,12 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/ssh"
nbStatus "github.com/netbirdio/netbird/client/status"
"strings" "strings"
"time" "time"
"github.com/netbirdio/netbird/client/ssh"
nbStatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -108,6 +109,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(), PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(), KernelInterface: iface.WireguardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
statusRecorder.UpdateLocalPeerState(localPeerState) statusRecorder.UpdateLocalPeerState(localPeerState)
@@ -184,12 +186,14 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address, WgAddr: peerConfig.Address,
IFaceBlackList: config.IFaceBlackList, IFaceBlackList: config.IFaceBlackList,
WgPrivateKey: key, DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPort: config.WgPort, WgPrivateKey: key,
SSHKey: []byte(config.SSHKey), WgPort: config.WgPort,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {

View File

@@ -0,0 +1,41 @@
package dns
import (
"context"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
"time"
)
const dbusDefaultFlag = 0
func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
obj, closeConn, err := getDbusObject(dest, path)
if err != nil {
return false
}
defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store()
return err == nil
}
func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) {
conn, err := dbus.SystemBus()
if err != nil {
return nil, nil, err
}
obj := conn.Object(dest, path)
closeFunc := func() {
closeErr := conn.Close()
if closeErr != nil {
log.Warnf("got an error closing dbus connection, err: %s", closeErr)
}
}
return obj, closeFunc, nil
}

View File

@@ -0,0 +1,154 @@
package dns
import (
"bytes"
"fmt"
log "github.com/sirupsen/logrus"
"os"
)
const (
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfSearchBeginContent = "search "
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n"
)
const (
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
)
var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
type fileConfigurator struct {
originalPerms os.FileMode
}
func newFileConfigurator() (hostManager, error) {
return &fileConfigurator{}, nil
}
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
backupFileExist := false
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
if err == nil {
backupFileExist = true
}
if !config.routeAll {
if backupFileExist {
err = f.restore()
if err != nil {
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err)
}
}
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
}
managerType, err := getOSDNSManagerType()
if err != nil {
return err
}
switch managerType {
case fileManager, netbirdManager:
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
}
}
default:
// todo improve this and maybe restart DNS manager from scratch
return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent")
}
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly {
continue
}
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
}
searchDomains += " " + dConf.domain
appendedDomains++
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
if err != nil {
err = f.restore()
if err != nil {
log.Errorf("attempt to restore default file failed with error: %s", err)
}
return err
}
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)
return nil
}
func (f *fileConfigurator) restoreHostDNS() error {
return f.restore()
}
func (f *fileConfigurator) backup() error {
stats, err := os.Stat(defaultResolvConfPath)
if err != nil {
return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err)
}
f.originalPerms = stats.Mode()
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
if err != nil {
return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err)
}
return nil
}
func (f *fileConfigurator) restore() error {
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
}
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}
func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
log.Debugf("creating managed file %s", fileName)
var buf bytes.Buffer
buf.WriteString(content)
err := os.WriteFile(fileName, buf.Bytes(), permissions)
if err != nil {
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
}
return nil
}
func copyFile(src, dest string) error {
stats, err := os.Stat(src)
if err != nil {
return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %s", src, err)
}
bytesRead, err := os.ReadFile(src)
if err != nil {
return fmt.Errorf("got an error while reading the file %s file for copy. Error: %s", src, err)
}
err = os.WriteFile(dest, bytesRead, stats.Mode())
if err != nil {
return fmt.Errorf("got an writing the destination file %s for copy. Error: %s", dest, err)
}
return nil
}

View File

@@ -0,0 +1,79 @@
package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"strings"
)
type hostManager interface {
applyDNSConfig(config hostDNSConfig) error
restoreHostDNS() error
}
type hostDNSConfig struct {
domains []domainConfig
routeAll bool
serverIP string
serverPort int
}
type domainConfig struct {
domain string
matchOnly bool
}
type mockHostConfigurator struct {
applyDNSConfigFunc func(config hostDNSConfig) error
restoreHostDNSFunc func() error
}
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error {
if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config)
}
return fmt.Errorf("method applyDNSSettings is not implemented")
}
func (m *mockHostConfigurator) restoreHostDNS() error {
if m.restoreHostDNSFunc != nil {
return m.restoreHostDNSFunc()
}
return fmt.Errorf("method restoreHostDNS is not implemented")
}
func newNoopHostMocker() hostManager {
return &mockHostConfigurator{
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil },
restoreHostDNSFunc: func() error { return nil },
}
}
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig {
config := hostDNSConfig{
routeAll: false,
serverIP: ip,
serverPort: port,
}
for _, nsConfig := range dnsConfig.NameServerGroups {
if nsConfig.Primary {
config.routeAll = true
}
for _, domain := range nsConfig.Domains {
config.domains = append(config.domains, domainConfig{
domain: strings.TrimSuffix(domain, "."),
matchOnly: true,
})
}
}
for _, customZone := range dnsConfig.CustomZones {
config.domains = append(config.domains, domainConfig{
domain: strings.TrimSuffix(customZone.Domain, "."),
matchOnly: false,
})
}
return config
}

View File

@@ -0,0 +1,259 @@
package dns
import (
"bufio"
"bytes"
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os/exec"
"strconv"
"strings"
)
const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses"
keyServerPort = "ServerPort"
arraySymbol = "* "
digitSymbol = "# "
scutilPath = "/usr/sbin/scutil"
searchSuffix = "Search"
matchSuffix = "Match"
)
type systemConfigurator struct {
// primaryServiceID primary interface in the system. AKA the interface with the default route
primaryServiceID string
createdKeys map[string]struct{}
}
func newHostManager(_ *iface.WGIface) (hostManager, error) {
return &systemConfigurator{
createdKeys: make(map[string]struct{}),
}, nil
}
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
var err error
if config.routeAll {
err = s.addDNSSetupForAll(config.serverIP, config.serverPort)
if err != nil {
return err
}
} else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil {
return err
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort)
}
var (
searchDomains []string
matchDomains []string
)
for _, dConf := range config.domains {
if dConf.matchOnly {
matchDomains = append(matchDomains, dConf.domain)
continue
}
searchDomains = append(searchDomains, dConf.domain)
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort)
} else {
log.Infof("removing match domains from the system")
err = s.removeKeyFromSystemConfig(matchKey)
}
if err != nil {
return err
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort)
} else {
log.Infof("removing search domains from the system")
err = s.removeKeyFromSystemConfig(searchKey)
}
if err != nil {
return err
}
return nil
}
func (s *systemConfigurator) restoreHostDNS() error {
lines := ""
for key := range s.createdKeys {
lines += buildRemoveKeyOperation(key)
keyType := "search"
if strings.Contains(key, matchSuffix) {
keyType = "match"
}
log.Infof("removing %s domains from system", keyType)
}
if s.primaryServiceID != "" {
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
log.Infof("restoring DNS resolver configuration for system")
}
_, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err)
return err
}
return nil
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line))
if err != nil {
return err
}
delete(s.createdKeys, key)
return nil
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {
return err
}
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
s.createdKeys[key] = struct{}{}
return nil
}
func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error {
err := s.addDNSState(key, domains, dnsServer, port, false)
if err != nil {
return err
}
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
s.createdKeys[key] = struct{}{}
return nil
}
func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port int, enableSearch bool) error {
noSearch := "1"
if enableSearch {
noSearch = "0"
}
lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains)
lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch)
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(state, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("got error while applying state for domains %s, error: %s", domains, err)
}
return nil
}
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey := s.getPrimaryService()
if primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key")
}
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
if err != nil {
return err
}
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey
return nil
}
func (s *systemConfigurator) getPrimaryService() string {
line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
log.Error("got error while sending the command: ", err)
return ""
}
scanner := bufio.NewScanner(bytes.NewReader(b))
for scanner.Scan() {
text := scanner.Text()
if strings.Contains(text, "PrimaryService") {
return strings.TrimSpace(strings.Split(text, ":")[1])
}
}
return ""
}
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("got error while applying dns setup, error: %s", err)
}
return nil
}
func getKeyWithInput(format, key string) string {
return fmt.Sprintf(format, key)
}
func buildAddCommandLine(key, value string) string {
return buildCommandLine("d.add", key, value)
}
func buildCommandLine(action, key, value string) string {
return fmt.Sprintf("%s %s %s\n", action, key, value)
}
func wrapCommand(commands string) string {
return fmt.Sprintf("open\n%s\nquit\n", commands)
}
func buildRemoveKeyOperation(key string) string {
return fmt.Sprintf("remove %s\n", key)
}
func buildCreateStateWithOperation(state, commands string) string {
return buildWriteStateOperation("set", state, commands)
}
func buildWriteStateOperation(operation, state, commands string) string {
return fmt.Sprintf("d.init\n%s %s\n%s\nset %s\n", operation, state, commands, state)
}
func runSystemConfigCommand(command string) ([]byte, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader(command)
out, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("got error while running system configuration command: \"%s\", error: %s", command, err)
}
return out, nil
}

View File

@@ -0,0 +1,87 @@
package dns
import (
"bufio"
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os"
"strings"
)
const (
defaultResolvConfPath = "/etc/resolv.conf"
)
const (
netbirdManager osManagerType = iota
fileManager
networkManager
systemdManager
resolvConfManager
)
type osManagerType int
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, err
}
log.Debugf("discovered mode is: %d", osManager)
switch osManager {
case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface)
case systemdManager:
return newSystemdDbusConfigurator(wgInterface)
case resolvConfManager:
return newResolvConfConfigurator(wgInterface)
default:
return newFileConfigurator()
}
}
func getOSDNSManagerType() (osManagerType, error) {
file, err := os.Open(defaultResolvConfPath)
if err != nil {
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %s", defaultResolvConfPath, err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
if len(text) == 0 {
continue
}
if text[0] != '#' {
return fileManager, nil
}
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return systemdManager, nil
}
if strings.Contains(text, "resolvconf") {
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
var value string
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
if err == nil {
if value == systemdDbusResolvConfModeForeign {
return systemdManager, nil
}
}
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
}
return resolvConfManager, nil
}
}
return fileManager, nil
}

View File

@@ -0,0 +1,260 @@
package dns
import (
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
"strings"
)
const (
dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match"
dnsPolicyConfigVersionKey = "Version"
dnsPolicyConfigVersionValue = 2
dnsPolicyConfigNameKey = "Name"
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8
)
const (
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"
)
type registryConfigurator struct {
guid string
routingAll bool
existingSearchDomains []string
}
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil {
return nil, err
}
return &registryConfigurator{
guid: guid,
}, nil
}
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
var err error
if config.routeAll {
err = r.addDNSSetupForAll(config.serverIP)
if err != nil {
return err
}
} else if r.routingAll {
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
if err != nil {
return err
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP)
}
var (
searchDomains []string
matchDomains []string
)
for _, dConf := range config.domains {
if !dConf.matchOnly {
searchDomains = append(searchDomains, dConf.domain)
}
matchDomains = append(matchDomains, "."+dConf.domain)
}
if len(matchDomains) != 0 {
err = r.addDNSMatchPolicy(matchDomains, config.serverIP)
} else {
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
}
if err != nil {
return err
}
err = r.updateSearchDomains(searchDomains)
if err != nil {
return err
}
return nil
}
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
if err != nil {
return fmt.Errorf("adding dns setup for all failed with error: %s", err)
}
r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
return nil
}
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE)
if err == nil {
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
if err != nil {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
}
}
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
}
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err)
}
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err)
}
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err)
}
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err)
}
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
return nil
}
func (r *registryConfigurator) restoreHostDNS() error {
err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
if err != nil {
log.Error(err)
}
return r.updateSearchDomains([]string{})
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey)
if err != nil {
return fmt.Errorf("unable to get current search domains failed with error: %s", err)
}
valueList := strings.Split(value, ",")
setExisting := false
if len(r.existingSearchDomains) == 0 {
r.existingSearchDomains = valueList
setExisting = true
}
if len(domains) == 0 && setExisting {
log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains)
return nil
}
newList := append(r.existingSearchDomains, domains...)
err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ","))
if err != nil {
return fmt.Errorf("adding search domain failed with error: %s", err)
}
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
return nil
}
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return err
}
defer regKey.Close()
err = regKey.SetStringValue(key, value)
if err != nil {
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err)
}
return nil
}
func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return err
}
defer regKey.Close()
err = regKey.DeleteValue(propertyKey)
if err != nil {
return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err)
}
return nil
}
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
var regKey registry.Key
regKeyPath := interfaceConfigPath + "\\" + r.guid
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
if err != nil {
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
}
return regKey, nil
}
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
if err == nil {
k.Close()
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
if err != nil {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
}
}
return nil
}
func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
}
defer regKey.Close()
val, _, err := regKey.GetStringValue(key)
if err != nil {
return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err)
}
return val, nil
}
func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
}
defer regKey.Close()
err = regKey.SetStringValue(key, value)
if err != nil {
return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err)
}
return nil
}

View File

@@ -0,0 +1,66 @@
package dns
import (
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync"
)
type localResolver struct {
registeredMap registrationMap
records sync.Map
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received question: %#v\n", r.Question[0])
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
replyMessage.Rcode = dns.RcodeSuccess
response := d.lookupRecord(r)
if response != nil {
replyMessage.Answer = append(replyMessage.Answer, response)
}
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
question := r.Question[0]
record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype))
if !found {
return nil
}
return record.(dns.RR)
}
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
fullRecord, err := dns.NewRR(record.String())
if err != nil {
return err
}
fullRecord.Header().Rdlength = record.Len()
header := fullRecord.Header()
d.records.Store(buildRecordKey(header.Name, header.Class, header.Rrtype), fullRecord)
return nil
}
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}
func buildRecordKey(name string, class, qType uint16) string {
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
return key
}

View File

@@ -0,0 +1,86 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"strings"
"testing"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@@ -0,0 +1,35 @@
package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
)
// MockServer is the mock instance of a dns server
type MockServer struct {
StartFunc func()
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
}
// Start mock implementation of Start from Server interface
func (m *MockServer) Start() {
if m.StartFunc != nil {
m.StartFunc()
}
}
// Stop mock implementation of Stop from Server interface
func (m *MockServer) Stop() {
if m.StopFunc != nil {
m.StopFunc()
}
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil {
return m.UpdateDNSServerFunc(serial, update)
}
return fmt.Errorf("method UpdateDNSServer is not implemented")
}

View File

@@ -0,0 +1,25 @@
package dns
import (
"github.com/miekg/dns"
"net"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@@ -0,0 +1,295 @@
package dns
import (
"context"
"encoding/binary"
"fmt"
"github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"net/netip"
"regexp"
"time"
)
const (
networkManagerDest = "org.freedesktop.NetworkManager"
networkManagerDbusObjectNode = "/org/freedesktop/NetworkManager"
networkManagerDbusDNSManagerInterface = "org.freedesktop.NetworkManager.DnsManager"
networkManagerDbusDNSManagerObjectNode = networkManagerDbusObjectNode + "/DnsManager"
networkManagerDbusDNSManagerModeProperty = networkManagerDbusDNSManagerInterface + ".Mode"
networkManagerDbusDNSManagerRcManagerProperty = networkManagerDbusDNSManagerInterface + ".RcManager"
networkManagerDbusVersionProperty = "org.freedesktop.NetworkManager.Version"
networkManagerDbusGetDeviceByIPIfaceMethod = networkManagerDest + ".GetDeviceByIpIface"
networkManagerDbusDeviceInterface = "org.freedesktop.NetworkManager.Device"
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
networkManagerDbusIPv4Key = "ipv4"
networkManagerDbusIPv6Key = "ipv6"
networkManagerDbusDNSKey = "dns"
networkManagerDbusDNSSearchKey = "dns-search"
networkManagerDbusDNSPriorityKey = "dns-priority"
// dns priority doc https://wiki.gnome.org/Projects/NetworkManager/DNS
networkManagerDbusPrimaryDNSPriority int32 = -500
networkManagerDbusWithMatchDomainPriority int32 = 0
networkManagerDbusSearchDomainOnlyPriority int32 = 50
supportedNetworkManagerVersionConstraint = ">= 1.16, < 1.28"
)
type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
}
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.Device.html on Network Manager input types
// networkManagerConnSettings maps to a (a{sa{sv}}) dbus output from GetAppliedConnection and input for Reapply methods
type networkManagerConnSettings map[string]map[string]dbus.Variant
// networkManagerConfigVersion maps to a (t) dbus output from GetAppliedConnection and input for Reapply methods
type networkManagerConfigVersion uint64
// networkManagerConfigBehavior maps to a (u) dbus input for GetAppliedConnection and Reapply methods
type networkManagerConfigBehavior uint32
// cleanDeprecatedSettings cleans deprecated settings that still returned by
// the GetAppliedConnection methods but can't be reApplied
func (s networkManagerConnSettings) cleanDeprecatedSettings() {
for _, key := range []string{"addresses", "routes"} {
delete(s[networkManagerDbusIPv4Key], key)
delete(s[networkManagerDbusIPv6Key], key)
}
}
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, err
}
defer closeConn()
var s string
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.GetName()).Store(&s)
if err != nil {
return nil, err
}
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.GetName())
return &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
}, nil
}
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil {
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
}
connSettings.cleanDeprecatedSettings()
dnsIP := netip.MustParseAddr(config.serverIP)
convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice())
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
var (
searchDomains []string
matchDomains []string
)
for _, dConf := range config.domains {
if dConf.matchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
continue
}
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain))
}
newDomainList := append(searchDomains, matchDomains...)
priority := networkManagerDbusSearchDomainOnlyPriority
switch {
case config.routeAll:
priority = networkManagerDbusPrimaryDNSPriority
newDomainList = append(newDomainList, "~.")
if !n.routingAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
}
case len(matchDomains) > 0:
priority = networkManagerDbusWithMatchDomainPriority
}
if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
n.routingAll = false
}
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
err = n.reApplyConnectionSettings(connSettings, configVersion)
if err != nil {
return fmt.Errorf("got an error while reapplying the connection with new settings, error: %s", err)
}
return nil
}
func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
// once the interface is gone network manager cleans all config associated with it
return n.deleteConnectionSettings()
}
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
if err != nil {
return nil, 0, fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
}
defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
var (
connSettings networkManagerConnSettings
configVersion networkManagerConfigVersion
)
err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag,
networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion)
if err != nil {
return nil, 0, fmt.Errorf("got error while calling GetAppliedConnection method with context, err: %s", err)
}
return connSettings, configVersion, nil
}
func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error {
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
if err != nil {
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
}
defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag,
connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store()
if err != nil {
return fmt.Errorf("got error while calling ReApply method with context, err: %s", err)
}
return nil
}
func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
if err != nil {
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
}
defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store()
if err != nil {
return fmt.Errorf("got error while calling delete method with context, err: %s", err)
}
return nil
}
func isNetworkManagerSupported() bool {
return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode()
}
func isNetworkManagerSupportedMode() bool {
var mode string
err := getNetworkManagerDNSProperty(networkManagerDbusDNSManagerModeProperty, &mode)
if err != nil {
log.Error(err)
return false
}
switch mode {
case "dnsmasq", "unbound", "systemd-resolved":
return true
default:
var rcManager string
err = getNetworkManagerDNSProperty(networkManagerDbusDNSManagerRcManagerProperty, &rcManager)
if err != nil {
log.Error(err)
return false
}
if rcManager == "unmanaged" {
return false
}
}
return true
}
func getNetworkManagerDNSProperty(property string, store any) error {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
if err != nil {
return fmt.Errorf("got error while attempting to retrieve the network manager dns manager object, error: %s", err)
}
defer closeConn()
v, e := obj.GetProperty(property)
if e != nil {
return fmt.Errorf("got an error getting property %s: %v", property, e)
}
return v.Store(store)
}
func isNetworkManagerSupportedVersion() bool {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
log.Errorf("got error while attempting to get the network manager object, err: %s", err)
return false
}
defer closeConn()
value, err := obj.GetProperty(networkManagerDbusVersionProperty)
if err != nil {
log.Errorf("unable to retrieve network manager mode, got error: %s", err)
return false
}
versionValue, err := parseVersion(value.Value().(string))
if err != nil {
return false
}
constraints, err := version.NewConstraint(supportedNetworkManagerVersionConstraint)
if err != nil {
return false
}
return constraints.Check(versionValue)
}
func parseVersion(inputVersion string) (*version.Version, error) {
reg, err := regexp.Compile(version.SemverRegexpRaw)
if err != nil {
return nil, err
}
if inputVersion == "" || !reg.MatchString(inputVersion) {
return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer")
}
verObj, err := version.NewVersion(inputVersion)
if err != nil {
return nil, err
}
return verObj, nil
}

View File

@@ -0,0 +1,84 @@
package dns
import (
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os/exec"
"strings"
)
const resolvconfCommand = "resolvconf"
type resolvconf struct {
ifaceName string
}
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
return &resolvconf{
ifaceName: wgInterface.GetName(),
}, nil
}
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
var err error
if !config.routeAll {
err = r.restoreHostDNS()
if err != nil {
log.Error(err)
}
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
}
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly {
continue
}
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
}
searchDomains += " " + dConf.domain
appendedDomains++
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
err = r.applyConfig(content)
if err != nil {
return err
}
log.Infof("added %d search domains. Search list: %s", appendedDomains, searchDomains)
return nil
}
func (r *resolvconf) restoreHostDNS() error {
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
_, err := cmd.Output()
if err != nil {
return fmt.Errorf("got an error while removing resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
}
return nil
}
func (r *resolvconf) applyConfig(content string) error {
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
cmd.Stdin = strings.NewReader(content)
_, err := cmd.Output()
if err != nil {
return fmt.Errorf("got an error while appying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
}
return nil
}

View File

@@ -0,0 +1,348 @@
package dns
import (
"context"
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"net"
"net/netip"
"runtime"
"sync"
"time"
)
const (
defaultPort = 53
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
// Server is a dns server interface
type Server interface {
Start()
Stop()
UpdateDNSServer(serial uint64, update nbdns.Config) error
}
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registrationMap
localResolver *localResolver
wgInterface *iface.WGIface
hostManager hostManager
updateSerial uint64
listenerIsRunning bool
runtimePort int
runtimeIP string
}
type registrationMap map[string]struct{}
type muxUpdate struct {
domain string
handler dns.Handler
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*DefaultServer, error) {
mux := dns.NewServeMux()
dnsServer := &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
}
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
stop: stop,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registrationMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
runtimePort: defaultPort,
}
hostmanager, err := newHostManager(wgInterface)
if err != nil {
return nil, err
}
defaultServer.hostManager = hostmanager
return defaultServer, err
}
// Start runs the listener in a go routine
func (s *DefaultServer) Start() {
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
}
s.runtimeIP = ip
s.runtimePort = port
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err = s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
}
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" && s.wgInterface != nil {
ips = append([]string{s.wgInterface.GetAddress().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.mux.Lock()
defer s.mux.Unlock()
s.stop()
err := s.hostManager.restoreHostDNS()
if err != nil {
log.Error(err)
}
err = s.stopListener()
if err != nil {
log.Error(err)
}
}
func (s *DefaultServer) stopListener() error {
if !s.listenerIsRunning {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
}
return nil
}
// UpdateDNSServer processes an update received from the management service
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
select {
case <-s.ctx.Done():
log.Infof("not updating DNS server as context is closed")
return s.ctx.Err()
default:
if serial < s.updateSerial {
return fmt.Errorf("not applying dns update, error: "+
"network update is %d behind the last applied update", s.updateSerial-serial)
}
s.mux.Lock()
defer s.mux.Unlock()
// is the service should be disabled, we stop the listener
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.Start()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort))
if err != nil {
log.Error(err)
}
s.updateSerial = serial
return nil
}
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate
localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: customZone.Domain,
handler: s.localResolver,
})
for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass {
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
}
key := buildRecordKey(record.Name, class, uint16(record.Type))
localRecords[key] = record
}
}
return muxUpdates, localRecords, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 {
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
}
handler := &upstreamResolver{
parentCTX: s.ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: defaultUpstreamTimeout,
}
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
}
if len(handler.upstreamServers) == 0 {
log.Errorf("received a nameserver group with an invalid nameserver list")
continue
}
if nsGroup.Primary {
muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
handler: handler,
})
continue
}
if len(nsGroup.Domains) == 0 {
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
})
}
}
return muxUpdates, nil
}
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registrationMap)
for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler)
muxUpdateMap[update.domain] = struct{}{}
}
for key := range s.dnsMuxMap {
_, found := muxUpdateMap[key]
if !found {
s.deregisterMux(key)
}
}
s.dnsMuxMap = muxUpdateMap
}
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
s.localResolver.deleteRecord(key)
}
}
updatedMap := make(registrationMap)
for key, record := range update {
err := s.localResolver.registerRecord(record)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
}
updatedMap[key] = struct{}{}
}
s.localResolver.registeredMap = updatedMap
}
func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}

View File

@@ -0,0 +1,320 @@
package dns
import (
"context"
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
)
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
},
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap registrationMap
initLocalMap registrationMap
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registrationMap
expectedLocalMap registrationMap
}{
{
name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
},
{
name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
shouldFail: true,
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registrationMap),
expectedLocalMap: make(registrationMap),
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager("127.0.0.1")
dnsServer.hostManager = newNoopHostMocker()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
}
}
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
}
for key := range testCase.expectedLocalMap {
_, found := dnsServer.localResolver.registeredMap[key]
if !found {
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
}
}
})
}
}
func TestDNSServerStartStop(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager("127.0.0.1")
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
// todo review why this test is not working only on github actions workflows
t.Skip("skipping test in Windows CI workflows.")
}
dnsServer.hostManager = newNoopHostMocker()
dnsServer.Start()
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
// retry test before exit, for slower systems
return d.DialContext(ctx, network, addr)
}
return conn, nil
},
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed to connect to the server, error: %v", err)
}
t.Log(ips)
if ips[0] != zoneRecords[0].RData {
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
}
dnsServer.Stop()
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1)
defer cancel()
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
if err == nil {
t.Fatalf("we should encounter an error when querying a stopped server")
}
}
func getDefaultServerWithNoHostManager(ip string) *DefaultServer {
mux := dns.NewServeMux()
listenIP := defaultIP
if ip != "" {
listenIP = ip
}
dnsServer := &dns.Server{
Addr: fmt.Sprintf("%s:%d", ip, defaultPort),
Net: "udp",
Handler: mux,
UDPSize: 65535,
}
ctx, stop := context.WithCancel(context.TODO())
return &DefaultServer{
ctx: ctx,
stop: stop,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registrationMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
runtimePort: defaultPort,
runtimeIP: listenIP,
}
}

View File

@@ -0,0 +1,202 @@
package dns
import (
"context"
"fmt"
"github.com/godbus/dbus/v5"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"net"
"net/netip"
"time"
)
const (
systemdDbusManagerInterface = "org.freedesktop.resolve1.Manager"
systemdResolvedDest = "org.freedesktop.resolve1"
systemdDbusObjectNode = "/org/freedesktop/resolve1"
systemdDbusGetLinkMethod = systemdDbusManagerInterface + ".GetLink"
systemdDbusFlushCachesMethod = systemdDbusManagerInterface + ".FlushCaches"
systemdDbusResolvConfModeProperty = systemdDbusManagerInterface + ".ResolvConfMode"
systemdDbusLinkInterface = "org.freedesktop.resolve1.Link"
systemdDbusRevertMethodSuffix = systemdDbusLinkInterface + ".Revert"
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusResolvConfModeForeign = "foreign"
)
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
}
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
// systemdDbusDNSInput maps to a (iay) dbus input for SetDNS method
type systemdDbusDNSInput struct {
Family int32
Address []byte
}
// systemdDbusLinkDomainsInput maps to a (sb) dbus input for SetDomains method
type systemdDbusLinkDomainsInput struct {
Domain string
MatchOnly bool
}
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.GetName())
if err != nil {
return nil, err
}
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return nil, err
}
defer closeConn()
var s string
err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s)
if err != nil {
return nil, err
}
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
return &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
}, nil
}
func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
parsedIP := netip.MustParseAddr(config.serverIP).As4()
defaultLinkInput := systemdDbusDNSInput{
Family: unix.AF_INET,
Address: parsedIP[:],
}
err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
if err != nil {
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.serverIP, config.serverPort, err)
}
var (
searchDomains []string
matchDomains []string
domainsInput []systemdDbusLinkDomainsInput
)
for _, dConf := range config.domains {
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dns.Fqdn(dConf.domain),
MatchOnly: dConf.matchOnly,
})
if dConf.matchOnly {
matchDomains = append(matchDomains, dConf.domain)
continue
}
searchDomains = append(searchDomains, dConf.domain)
}
if config.routeAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
if err != nil {
return fmt.Errorf("setting link as default dns router, failed with error: %s", err)
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: nbdns.RootZone,
MatchOnly: true,
})
s.routingAll = true
} else if s.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
}
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
err = s.setDomainsForInterface(domainsInput)
if err != nil {
log.Error(err)
}
return nil
}
func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error {
err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput)
if err != nil {
return fmt.Errorf("setting domains configuration failed with error: %s", err)
}
return s.flushCaches()
}
func (s *systemdDbusConfigurator) restoreHostDNS() error {
log.Infof("reverting link settings and flushing cache")
if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) {
return nil
}
err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil)
if err != nil {
return fmt.Errorf("unable to revert link configuration, got error: %s", err)
}
return s.flushCaches()
}
func (s *systemdDbusConfigurator) flushCaches() error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return fmt.Errorf("got error while attempting to retrieve the object %s, err: %s", systemdDbusObjectNode, err)
}
defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store()
if err != nil {
return fmt.Errorf("got error while calling the FlushCaches method with context, err: %s", err)
}
return nil
}
func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject)
if err != nil {
return fmt.Errorf("got error while attempting to retrieve the object, err: %s", err)
}
defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
if value != nil {
err = obj.CallWithContext(ctx, method, dbusDefaultFlag, value).Store()
} else {
err = obj.CallWithContext(ctx, method, dbusDefaultFlag).Store()
}
if err != nil {
return fmt.Errorf("got error while calling command with context, err: %s", err)
}
return nil
}
func getSystemdDbusProperty(property string, store any) error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return fmt.Errorf("got error while attempting to retrieve the systemd dns manager object, error: %s", err)
}
defer closeConn()
v, e := obj.GetProperty(property)
if e != nil {
return fmt.Errorf("got an error getting property %s: %v", property, e)
}
return v.Store(store)
}

View File

@@ -0,0 +1,67 @@
package dns
import (
"context"
"errors"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"net"
"time"
)
const defaultUpstreamTimeout = 15 * time.Second
type upstreamResolver struct {
parentCTX context.Context
upstreamClient *dns.Client
upstreamServers []string
upstreamTimeout time.Duration
}
// ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received an upstream question: %#v", r.Question[0])
select {
case <-u.parentCTX.Done():
return
default:
}
for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
continue
}
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
return
}
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
if err != nil {
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
}
return
}
log.Errorf("all queries to the upstream nameservers failed with timeout")
}
// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
func isTimeout(err error) bool {
var neterr net.Error
if errors.As(err, &neterr) {
return neterr != nil && neterr.Timeout()
}
return false
}

View File

@@ -0,0 +1,110 @@
package dns
import (
"context"
"github.com/miekg/dns"
"strings"
"testing"
"time"
)
func TestUpstreamResolver_ServeDNS(t *testing.T) {
testCases := []struct {
name string
inputMSG *dns.Msg
responseShouldBeNil bool
InputServers []string
timeout time.Duration
cancelCTX bool
expectedAnswer string
}{
{
name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: defaultUpstreamTimeout,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Resolve If First Upstream Times Out",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
timeout: 2 * time.Second,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Not Resolve If Can't Connect To Both Servers",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
timeout: 200 * time.Millisecond,
responseShouldBeNil: true,
},
{
name: "Should Not Resolve If Parent Context Is Canceled",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true,
timeout: defaultUpstreamTimeout,
responseShouldBeNil: true,
},
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
//},
}
// should resolve if first upstream times out
// should not write when both fails
// should not resolve if parent context is canceled
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver := &upstreamResolver{
parentCTX: ctx,
upstreamClient: &dns.Client{},
upstreamServers: testCase.InputServers,
upstreamTimeout: testCase.timeout,
}
if testCase.cancelCTX {
cancel()
} else {
defer cancel()
}
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
foundAnswer := false
for _, answer := range responseMSG.Answer {
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
}
if !foundAnswer {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
}
})
}
}

View File

@@ -3,18 +3,22 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/route"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -51,7 +55,8 @@ type EngineConfig struct {
WgPrivateKey wgtypes.Key WgPrivateKey wgtypes.Key
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related) // IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
IFaceBlackList []string IFaceBlackList []string
DisableIPv6Discovery bool
PreSharedKey *wgtypes.Key PreSharedKey *wgtypes.Key
@@ -63,6 +68,8 @@ type EngineConfig struct {
// SSHKey is a private SSH key in a PEM format // SSHKey is a private SSH key in a PEM format
SSHKey []byte SSHKey []byte
NATExternalIPs []string
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -103,6 +110,8 @@ type Engine struct {
statusRecorder *nbstatus.Status statusRecorder *nbstatus.Status
routeManager routemanager.Manager routeManager routemanager.Manager
dnsServer dns.Server
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -190,6 +199,10 @@ func (e *Engine) Stop() error {
e.routeManager.Stop() e.routeManager.Stop()
} }
if e.dnsServer != nil {
e.dnsServer.Stop()
}
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
return nil return nil
@@ -213,13 +226,18 @@ func (e *Engine) Start() error {
return err return err
} }
e.udpMuxConn, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxPort}) networkName := "udp"
if e.config.DisableIPv6Discovery {
networkName = "udp4"
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil { if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error()) log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
return err return err
} }
e.udpMuxConnSrflx, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxSrflxPort}) e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil { if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error()) log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
return err return err
@@ -242,6 +260,14 @@ func (e *Engine) Start() error {
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
if e.dnsServer == nil {
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface)
if err != nil {
return err
}
e.dnsServer = dnsServer
}
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
@@ -255,9 +281,15 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// first, check if peers have been modified // first, check if peers have been modified
var modified []*mgmProto.RemotePeerConfig var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate { for _, p := range peersUpdate {
if peerConn, ok := e.peerConns[p.GetWgPubKey()]; ok { peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok {
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") { if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p) modified = append(modified, p)
continue
}
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
} }
} }
} }
@@ -517,6 +549,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
} }
} }
e.statusRecorder.UpdateLocalPeerState(nbstatus.LocalPeerState{
IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(),
FQDN: conf.GetFqdn(),
})
return nil return nil
} }
@@ -638,6 +677,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update routes, err: %v", err) log.Errorf("failed to update routes, err: %v", err)
} }
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
if err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial e.networkSerial = serial
return nil return nil
} }
@@ -660,6 +708,48 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes return routes
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
}
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
Name: record.GetName(),
Type: int(record.GetType()),
Class: record.GetClass(),
TTL: int(record.GetTTL()),
RData: record.GetRData(),
}
dnsZone.Records = append(dnsZone.Records, dnsRecord)
}
dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone)
}
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
dnsNSGroup := &nbdns.NameServerGroup{
Primary: nsGroup.GetPrimary(),
Domains: nsGroup.GetDomains(),
}
for _, ns := range nsGroup.GetNameServers() {
dnsNS := nbdns.NameServer{
IP: netip.MustParseAddr(ns.GetIP()),
NSType: nbdns.NameServerType(ns.GetNSType()),
Port: int(ns.GetPort()),
}
dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS)
}
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
}
return dnsUpdate
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update // addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate { for _, p := range peersUpdate {
@@ -689,6 +779,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
go e.connWorker(conn, peerKey) go e.connWorker(conn, peerKey)
} }
err := e.statusRecorder.UpdatePeerFQDN(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerKey, err)
}
return nil return nil
} }
@@ -755,15 +849,17 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
// randomize connection timeout // randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{ config := peer.ConnConfig{
Key: pubKey, Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(), LocalKey: e.config.WgPrivateKey.PublicKey().String(),
StunTurn: stunTurn, StunTurn: stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
Timeout: timeout, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux, Timeout: timeout,
UDPMuxSrflx: e.udpMuxSrflx, UDPMux: e.udpMux,
ProxyConfig: proxyConfig, UDPMuxSrflx: e.udpMuxSrflx,
LocalWgPort: e.config.WgPort, ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
} }
peerConn, err := peer.NewConn(config, e.statusRecorder) peerConn, err := peer.NewConn(config, e.statusRecorder)
@@ -857,3 +953,77 @@ func (e *Engine) receiveSignalEvents() {
e.signal.WaitStreamConnected() e.signal.WaitStreamConnected()
} }
func (e *Engine) parseNATExternalIPMappings() []string {
var mappedIPs []string
var ignoredIFaces = make(map[string]interface{})
for _, iFace := range e.config.IFaceBlackList {
ignoredIFaces[iFace] = nil
}
for _, mapping := range e.config.NATExternalIPs {
var external, internal string
var externalIP, internalIP net.IP
var err error
split := strings.Split(mapping, "/")
if len(split) > 2 {
log.Warnf("ignoring invalid external mapping '%s', too many delimiters", mapping)
break
}
if len(split) > 1 {
internal = split[1]
internalIP = net.ParseIP(internal)
if internalIP == nil {
// not a properly formatted IP address, maybe it's interface name?
if _, present := ignoredIFaces[internal]; present {
log.Warnf("internal interface '%s' in blacklist, ignoring external mapping '%s'", internal, mapping)
break
}
internalIP, err = findIPFromInterfaceName(internal)
if err != nil {
log.Warnf("error finding interface IP for interface '%s', ignoring external mapping '%s': %v", internal, mapping, err)
break
}
}
}
external = split[0]
externalIP = net.ParseIP(external)
if externalIP == nil {
log.Warnf("invalid external IP, ignoring external IP mapping '%s'", mapping)
break
}
if externalIP != nil {
mappedIP := externalIP.String()
if internalIP != nil {
mappedIP = mappedIP + "/" + internalIP.String()
}
mappedIPs = append(mappedIPs, mappedIP)
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
}
}
if len(mappedIPs) != len(e.config.NATExternalIPs) {
log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings")
return nil
}
return mappedIPs
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
return nil, err
}
return findIPFromInterface(iface)
}
func findIPFromInterface(iface *net.Interface) (net.IP, error) {
ifaceAddrs, err := iface.Addrs()
if err != nil {
return nil, err
}
for _, addr := range ifaceAddrs {
if ipv4Addr := addr.(*net.IPNet).IP.To4(); ipv4Addr != nil {
return ipv4Addr, nil
}
}
return nil, fmt.Errorf("interface %s don't have an ipv4 address", iface.Name)
}

View File

@@ -3,9 +3,11 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -69,6 +71,10 @@ func TestEngine_SSH(t *testing.T) {
WgPort: 33100, WgPort: 33100,
}, nbstatus.NewRecorder()) }, nbstatus.NewRecorder())
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
var sshKeysAdded []string var sshKeysAdded []string
var sshPeersRemoved []string var sshPeersRemoved []string
@@ -200,6 +206,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, nbstatus.NewRecorder()) }, nbstatus.NewRecorder())
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
type testCase struct { type testCase struct {
name string name string
@@ -380,6 +389,10 @@ func TestEngine_Sync(t *testing.T) {
WgPort: 33100, WgPort: 33100,
}, nbstatus.NewRecorder()) }, nbstatus.NewRecorder())
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() { defer func() {
err := engine.Stop() err := engine.Stop()
if err != nil { if err != nil {
@@ -440,7 +453,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial uint64 expectedSerial uint64
}{ }{
{ {
name: "Routes Update Should Be Passed To Manager", name: "Routes Config Should Be Passed To Manager",
networkMap: &mgmtProto.NetworkMap{ networkMap: &mgmtProto.NetworkMap{
Serial: 1, Serial: 1,
PeerConfig: nil, PeerConfig: nil,
@@ -486,7 +499,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial: 1, expectedSerial: 1,
}, },
{ {
name: "Empty Routes Update Should Be Passed", name: "Empty Routes Config Should Be Passed",
networkMap: &mgmtProto.NetworkMap{ networkMap: &mgmtProto.NetworkMap{
Serial: 1, Serial: 1,
PeerConfig: nil, PeerConfig: nil,
@@ -549,6 +562,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
} }
engine.routeManager = mockRouteManager engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{}
defer func() { defer func() {
exitErr := engine.Stop() exitErr := engine.Stop()
@@ -566,6 +580,183 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
} }
} }
func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
testCases := []struct {
name string
inputErr error
networkMap *mgmtProto.NetworkMap
expectedZonesLen int
expectedZones []nbdns.CustomZone
expectedNSGroupsLen int
expectedNSGroups []*nbdns.NameServerGroup
expectedSerial uint64
}{
{
name: "DNS Config Should Be Passed To DNS Server",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: &mgmtProto.DNSConfig{
ServiceEnable: true,
CustomZones: []*mgmtProto.CustomZone{
{
Domain: "netbird.cloud.",
Records: []*mgmtProto.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
NameServerGroups: []*mgmtProto.NameServerGroup{
{
Primary: true,
NameServers: []*mgmtProto.NameServer{
{
IP: "8.8.8.8",
NSType: 1,
Port: 53,
},
},
},
},
},
},
expectedZonesLen: 1,
expectedZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: 1,
Port: 53,
},
},
},
},
expectedSerial: 1,
},
{
name: "Empty DNS Config Should Be OK",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
{
name: "Error Shouldn't Break Engine",
inputErr: fmt.Errorf("mocking error"),
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
// test setup
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, nbstatus.NewRecorder())
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}
engine.routeManager = mockRouteManager
input := struct {
inputSerial uint64
inputNSGroups []*nbdns.NameServerGroup
inputZones []nbdns.CustomZone
}{}
mockDNSServer := &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error {
input.inputSerial = serial
input.inputZones = update.CustomZones
input.inputNSGroups = update.NameServerGroups
return testCase.inputErr
},
}
engine.dnsServer = mockDNSServer
defer func() {
exitErr := engine.Stop()
if exitErr != nil {
return
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
assert.Equal(t, testCase.expectedZones, input.inputZones, "custom zones should match")
assert.Len(t, input.inputNSGroups, testCase.expectedNSGroupsLen, "ns groups len should match")
assert.Equal(t, testCase.expectedNSGroups, input.inputNSGroups, "ns groups should match")
})
}
}
func TestEngine_MultiplePeers(t *testing.T) { func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel) // log.SetLevel(log.DebugLevel)
@@ -618,6 +809,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
t.Errorf("unable to create the engine for peer %d with error %v", j, err) t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return return
} }
engine.dnsServer = &dns.MockServer{}
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
err = engine.Start() err = engine.Start()
@@ -756,17 +948,17 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
return nil, err return nil, err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStore(config.Datadir) store, err := server.NewFileStore(config.Datadir)
if err != nil { if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
} }
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil) accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -2,18 +2,18 @@ package peer
import ( import (
"context" "context"
nbStatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl"
"net" "net"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/proxy"
nbStatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
) )
// ConnConfig is a peer Connection configuration // ConnConfig is a peer Connection configuration
@@ -29,7 +29,8 @@ type ConnConfig struct {
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used) // (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string InterfaceBlackList []string
DisableIPv6Discovery bool
Timeout time.Duration Timeout time.Duration
@@ -39,6 +40,8 @@ type ConnConfig struct {
UDPMuxSrflx ice.UniversalUDPMux UDPMuxSrflx ice.UniversalUDPMux
LocalWgPort int LocalWgPort int
NATExternalIPs []string
} }
// OfferAnswer represents a session establishment offer or answer // OfferAnswer represents a session establishment offer or answer
@@ -143,16 +146,24 @@ func (conn *Conn) reCreateAgent() error {
failedTimeout := 6 * time.Second failedTimeout := 6 * time.Second
var err error var err error
conn.agent, err = ice.NewAgent(&ice.AgentConfig{ agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled, MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4}, NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn, Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}, CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
FailedTimeout: &failedTimeout, FailedTimeout: &failedTimeout,
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList), InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux, UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx, UDPMuxSrflx: conn.config.UDPMuxSrflx,
}) NAT1To1IPs: conn.config.NATExternalIPs,
}
if conn.config.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
conn.agent, err = ice.NewAgent(agentConfig)
if err != nil { if err != nil {
return err return err
} }
@@ -284,7 +295,7 @@ func (conn *Conn) Open() error {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String()) host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String()) rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection // direct Wireguard connection
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, iface.DefaultWgPort, rhost, iface.DefaultWgPort) log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
} else { } else {
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String()) log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
} }
@@ -448,6 +459,7 @@ func (conn *Conn) SetSignalCandidate(handler func(candidate ice.Candidate) error
// and then signals them to the remote peer // and then signals them to the remote peer
func (conn *Conn) onICECandidate(candidate ice.Candidate) { func (conn *Conn) onICECandidate(candidate ice.Candidate) {
if candidate != nil { if candidate != nil {
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String()) log.Debugf("discovered local candidate %s", candidate.String())
go func() { go func() {
err := conn.signalCandidate(candidate) err := conn.signalCandidate(candidate)

View File

@@ -21,7 +21,7 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
} }
if prefixGateway != nil && !prefixGateway.Equal(gateway) { if prefixGateway != nil && !prefixGateway.Equal(gateway) {
log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway) log.Warnf("skipping adding a new route for network %s because it already exists and is pointing to the non default gateway: %s", prefix, prefixGateway)
return nil return nil
} }
return addToRouteTable(prefix, addr) return addToRouteTable(prefix, addr)
@@ -45,11 +45,14 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice()) _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
if err != nil { if err != nil {
log.Errorf("getting routes returned an error: %v", err) log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound return nil, errRouteNotFound
} }
if gateway == nil {
return preferredSrc, nil
}
return localGatewayAddress, nil return gateway, nil
} }

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net"
"net/netip" "net/netip"
"testing" "testing"
) )
@@ -66,3 +67,45 @@ func TestAddRemoveRoutes(t *testing.T) {
}) })
} }
} }
func TestGetExistingRIBRouteGateway(t *testing.T) {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if gateway == nil {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
continue
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
localIP, err := getExistingRIBRouteGateway(testingPrefix)
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if localIP == nil {
t.Fatal("should return a gateway for local network")
}
if localIP.String() == gateway.String() {
t.Fatal("local ip should not match with gateway IP")
}
if localIP.String() != testingIP {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
}
}

View File

@@ -1,16 +1,16 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.12.4 // protoc v3.21.9
// source: daemon.proto // source: daemon.proto
package proto package proto
import ( import (
_ "github.com/golang/protobuf/protoc-gen-go/descriptor"
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect" reflect "reflect"
sync "sync" sync "sync"
) )
@@ -645,14 +645,15 @@ type PeerState struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
} }
func (x *PeerState) Reset() { func (x *PeerState) Reset() {
@@ -708,7 +709,7 @@ func (x *PeerState) GetConnStatus() string {
return "" return ""
} }
func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp { func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp {
if x != nil { if x != nil {
return x.ConnStatusUpdate return x.ConnStatusUpdate
} }
@@ -743,6 +744,13 @@ func (x *PeerState) GetRemoteIceCandidateType() string {
return "" return ""
} }
func (x *PeerState) GetFqdn() string {
if x != nil {
return x.Fqdn
}
return ""
}
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
type LocalPeerState struct { type LocalPeerState struct {
state protoimpl.MessageState state protoimpl.MessageState
@@ -752,6 +760,7 @@ type LocalPeerState struct {
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
KernelInterface bool `protobuf:"varint,3,opt,name=kernelInterface,proto3" json:"kernelInterface,omitempty"` KernelInterface bool `protobuf:"varint,3,opt,name=kernelInterface,proto3" json:"kernelInterface,omitempty"`
Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
} }
func (x *LocalPeerState) Reset() { func (x *LocalPeerState) Reset() {
@@ -807,6 +816,13 @@ func (x *LocalPeerState) GetKernelInterface() bool {
return false return false
} }
func (x *LocalPeerState) GetFqdn() string {
if x != nil {
return x.Fqdn
}
return ""
}
// SignalState contains the latest state of a signal connection // SignalState contains the latest state of a signal connection
type SignalState struct { type SignalState struct {
state protoimpl.MessageState state protoimpl.MessageState
@@ -1053,7 +1069,7 @@ var file_daemon_proto_rawDesc = []byte{
0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70,
0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61,
0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61,
0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x22, 0xbb, 0x02, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x22, 0xcf, 0x02, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72,
0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28,
0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18,
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a,
@@ -1073,61 +1089,64 @@ var file_daemon_proto_rawDesc = []byte{
0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64,
0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72,
0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74,
0x65, 0x54, 0x79, 0x70, 0x65, 0x22, 0x62, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20,
0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x76, 0x0a, 0x0e, 0x4c, 0x6f, 0x63,
0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49,
0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70,
0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62,
0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74,
0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x22, 0x3d, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65,
0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64,
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6e, 0x22, 0x3d, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x22, 0x41, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18,
0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64,
0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x22, 0x41, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74,
0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x22, 0xef, 0x01, 0x0a, 0x0a, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63,
0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x74, 0x65, 0x64, 0x22, 0xef, 0x01, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61,
0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53,
0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53,
0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65,
0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e,
0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03,
0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f,
0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05,
0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61,
0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05,
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x70, 0x65, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52,
0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f,
0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64,
0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02,
0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53,
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64,
0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13,
0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75,
0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77,
0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42,
0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
} }
var ( var (
@@ -1144,24 +1163,24 @@ func file_daemon_proto_rawDescGZIP() []byte {
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 17) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 17)
var file_daemon_proto_goTypes = []interface{}{ var file_daemon_proto_goTypes = []interface{}{
(*LoginRequest)(nil), // 0: daemon.LoginRequest (*LoginRequest)(nil), // 0: daemon.LoginRequest
(*LoginResponse)(nil), // 1: daemon.LoginResponse (*LoginResponse)(nil), // 1: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest (*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse (*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 4: daemon.UpRequest (*UpRequest)(nil), // 4: daemon.UpRequest
(*UpResponse)(nil), // 5: daemon.UpResponse (*UpResponse)(nil), // 5: daemon.UpResponse
(*StatusRequest)(nil), // 6: daemon.StatusRequest (*StatusRequest)(nil), // 6: daemon.StatusRequest
(*StatusResponse)(nil), // 7: daemon.StatusResponse (*StatusResponse)(nil), // 7: daemon.StatusResponse
(*DownRequest)(nil), // 8: daemon.DownRequest (*DownRequest)(nil), // 8: daemon.DownRequest
(*DownResponse)(nil), // 9: daemon.DownResponse (*DownResponse)(nil), // 9: daemon.DownResponse
(*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest (*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse (*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse
(*PeerState)(nil), // 12: daemon.PeerState (*PeerState)(nil), // 12: daemon.PeerState
(*LocalPeerState)(nil), // 13: daemon.LocalPeerState (*LocalPeerState)(nil), // 13: daemon.LocalPeerState
(*SignalState)(nil), // 14: daemon.SignalState (*SignalState)(nil), // 14: daemon.SignalState
(*ManagementState)(nil), // 15: daemon.ManagementState (*ManagementState)(nil), // 15: daemon.ManagementState
(*FullStatus)(nil), // 16: daemon.FullStatus (*FullStatus)(nil), // 16: daemon.FullStatus
(*timestamp.Timestamp)(nil), // 17: google.protobuf.Timestamp (*timestamppb.Timestamp)(nil), // 17: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
16, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 16, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus

View File

@@ -105,6 +105,7 @@ message PeerState {
bool direct = 6; bool direct = 6;
string localIceCandidateType = 7; string localIceCandidateType = 7;
string remoteIceCandidateType =8; string remoteIceCandidateType =8;
string fqdn = 9;
} }
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
@@ -112,6 +113,7 @@ message LocalPeerState {
string IP = 1; string IP = 1;
string pubKey = 2; string pubKey = 2;
bool kernelInterface =3; bool kernelInterface =3;
string fqdn = 4;
} }
// SignalState contains the latest state of a signal connection // SignalState contains the latest state of a signal connection

View File

@@ -1,4 +1,17 @@
#!/bin/bash #!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26 go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I proto/ proto/daemon.proto --go_out=. --go-grpc_out=. protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"

View File

@@ -475,6 +475,7 @@ func toProtoFullStatus(fullStatus nbStatus.FullStatus) *proto.FullStatus {
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
for _, peerState := range fullStatus.Peers { for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{ pbPeerState := &proto.PeerState{
@@ -486,6 +487,7 @@ func toProtoFullStatus(fullStatus nbStatus.FullStatus) *proto.FullStatus {
Direct: peerState.Direct, Direct: peerState.Direct,
LocalIceCandidateType: peerState.LocalIceCandidateType, LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType, RemoteIceCandidateType: peerState.RemoteIceCandidateType,
Fqdn: peerState.FQDN,
} }
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
} }

View File

@@ -10,6 +10,7 @@ import (
type PeerState struct { type PeerState struct {
IP string IP string
PubKey string PubKey string
FQDN string
ConnStatus string ConnStatus string
ConnStatusUpdate time.Time ConnStatusUpdate time.Time
Relayed bool Relayed bool
@@ -23,6 +24,7 @@ type LocalPeerState struct {
IP string IP string
PubKey string PubKey string
KernelInterface bool KernelInterface bool
FQDN string
} }
// SignalState contains the latest state of a signal connection // SignalState contains the latest state of a signal connection
@@ -136,6 +138,22 @@ func (d *Status) UpdatePeerState(receivedState PeerState) error {
return nil return nil
} }
// UpdatePeerFQDN update peer's state fqdn only
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peerPubKey]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.FQDN = fqdn
d.peers[peerPubKey] = peerState
return nil
}
// GetPeerStateChangeNotifier returns a change notifier channel for a peer // GetPeerStateChangeNotifier returns a change notifier channel for a peer
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock() d.mux.Lock()

View File

@@ -54,6 +54,24 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal") assert.Equal(t, ip, state.IP, "ip should be equal")
} }
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"
status := NewRecorder()
peerState := PeerState{
PubKey: key,
}
status.peers[key] = peerState
err := status.UpdatePeerFQDN(key, fqdn)
assert.NoError(t, err, "shouldn't return error")
state, exists := status.peers[key]
assert.True(t, exists, "state should be found")
assert.Equal(t, fqdn, state.FQDN, "fqdn should be equal")
}
func TestGetPeerStateChangeNotifierLogic(t *testing.T) { func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
key := "abc" key := "abc"
ip := "10.10.10.10" ip := "10.10.10.10"

View File

@@ -8,7 +8,6 @@ import (
"context" "context"
"flag" "flag"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/system"
"os" "os"
"os/exec" "os/exec"
"path" "path"
@@ -18,6 +17,8 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/netbirdio/netbird/client/system"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
_ "embed" _ "embed"
@@ -61,6 +62,8 @@ func main() {
flag.Parse() flag.Parse()
a := app.New() a := app.New()
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
client := newServiceClient(daemonAddr, a, showSettings) client := newServiceClient(daemonAddr, a, showSettings)
if showSettings { if showSettings {
a.Run() a.Run()
@@ -113,7 +116,7 @@ type serviceClient struct {
iLogFile *widget.Entry iLogFile *widget.Entry
iPreSharedKey *widget.Entry iPreSharedKey *widget.Entry
// observable settings over correspondign iMngURL and iPreSharedKey values. // observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string managementURL string
preSharedKey string preSharedKey string
adminURL string adminURL string
@@ -121,7 +124,7 @@ type serviceClient struct {
// newServiceClient instance constructor // newServiceClient instance constructor
// //
// This constructor olso build UI elements for settings window. // This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient { func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient {
s := &serviceClient{ s := &serviceClient{
ctx: context.Background(), ctx: context.Background(),
@@ -149,7 +152,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
func (s *serviceClient) showUIElements() { func (s *serviceClient) showUIElements() {
// add settings window UI elements. // add settings window UI elements.
s.wSettings = s.app.NewWindow("Settings") s.wSettings = s.app.NewWindow("NetBird Settings")
s.iMngURL = widget.NewEntry() s.iMngURL = widget.NewEntry()
s.iAdminURL = widget.NewEntry() s.iAdminURL = widget.NewEntry()
s.iConfigFile = widget.NewEntry() s.iConfigFile = widget.NewEntry()
@@ -326,11 +329,13 @@ func (s *serviceClient) updateStatus() error {
if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() { if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() {
systray.SetIcon(s.icConnected) systray.SetIcon(s.icConnected)
systray.SetTooltip("NetBird (Connected)")
s.mStatus.SetTitle("Connected") s.mStatus.SetTitle("Connected")
s.mUp.Disable() s.mUp.Disable()
s.mDown.Enable() s.mDown.Enable()
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
systray.SetIcon(s.icDisconnected) systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird (Disconnected)")
s.mStatus.SetTitle("Disconnected") s.mStatus.SetTitle("Disconnected")
s.mDown.Disable() s.mDown.Disable()
s.mUp.Enable() s.mUp.Enable()
@@ -355,6 +360,7 @@ func (s *serviceClient) updateStatus() error {
func (s *serviceClient) onTrayReady() { func (s *serviceClient) onTrayReady() {
systray.SetIcon(s.icDisconnected) systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird")
// setup systray menu items // setup systray menu items
s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected") s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected")

View File

@@ -2,5 +2,108 @@
// to parse and normalize dns records and configuration // to parse and normalize dns records and configuration
package dns package dns
// DefaultDNSPort well-known port number import (
const DefaultDNSPort = 53 "fmt"
"github.com/miekg/dns"
"golang.org/x/net/idna"
"net"
"regexp"
"strings"
)
const (
// DefaultDNSPort well-known port number
DefaultDNSPort = 53
// RootZone is a string representation of the root zone
RootZone = "."
// DefaultClass is the class supported by the system
DefaultClass = "IN"
)
const invalidHostLabel = "[^a-zA-Z0-9-]+"
// Config represents a dns configuration that is exchanged between management and peers
type Config struct {
// ServiceEnable indicates if the service should be enabled
ServiceEnable bool
// NameServerGroups contains a list of nameserver group
NameServerGroups []*NameServerGroup
// CustomZones contains a list of custom zone
CustomZones []CustomZone
}
// CustomZone represents a custom zone to be resolved by the dns server
type CustomZone struct {
// Domain is the zone's domain
Domain string
// Records custom zone records
Records []SimpleRecord
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
type SimpleRecord struct {
// Name domain name
Name string
// Type of record, 1 for A, 5 for CNAME, 28 for AAAA. see https://pkg.go.dev/github.com/miekg/dns@v1.1.41#pkg-constants
Type int
// Class dns class, currently use the DefaultClass for all records
Class string
// TTL time-to-live for the record
TTL int
// RData is the actual value resolved in a dns query
RData string
}
// String returns a string of the simple record formatted as:
// <Name> <TTL> <Class> <Type> <RDATA>
func (s SimpleRecord) String() string {
fqdn := dns.Fqdn(s.Name)
return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData)
}
// Len returns the length of the RData field, based on its type
func (s SimpleRecord) Len() uint16 {
emptyString := s.RData == ""
switch s.Type {
case 1:
if emptyString {
return 0
}
return net.IPv4len
case 5:
if emptyString || s.RData == "." {
return 1
}
return uint16(len(s.RData) + 1)
case 28:
if emptyString {
return 0
}
return net.IPv6len
default:
return 0
}
}
// GetParsedDomainLabel returns a domain label with max 59 characters,
// parsed for old Hosts.txt requirements, and converted to ASCII and lowercase
func GetParsedDomainLabel(name string) (string, error) {
labels := dns.SplitDomainName(name)
if len(labels) == 0 {
return "", fmt.Errorf("got empty label list for name \"%s\"", name)
}
rawLabel := labels[0]
ascii, err := idna.Punycode.ToASCII(rawLabel)
if err != nil {
return "", fmt.Errorf("unable to convert host lavel to ASCII, error: %v", err)
}
invalidHostMatcher := regexp.MustCompile(invalidHostLabel)
validHost := strings.ToLower(invalidHostMatcher.ReplaceAllString(ascii, "-"))
if len(validHost) > 58 {
validHost = validHost[:59]
}
return validHost, nil
}

View File

@@ -9,8 +9,6 @@ import (
) )
const ( const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerType invalid nameserver type // InvalidNameServerType invalid nameserver type
InvalidNameServerType NameServerType = iota InvalidNameServerType NameServerType = iota
// UDPNameServerType udp nameserver type // UDPNameServerType udp nameserver type
@@ -18,6 +16,8 @@ const (
) )
const ( const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerTypeString invalid nameserver type as string // InvalidNameServerTypeString invalid nameserver type as string
InvalidNameServerTypeString = "invalid" InvalidNameServerTypeString = "invalid"
// UDPNameServerTypeString udp nameserver type as string // UDPNameServerTypeString udp nameserver type as string
@@ -59,6 +59,10 @@ type NameServerGroup struct {
NameServers []NameServer NameServers []NameServer
// Groups list of peer group IDs to distribute the nameservers information // Groups list of peer group IDs to distribute the nameservers information
Groups []string Groups []string
// Primary indicates that the nameserver group is the primary resolver for any dns query
Primary bool
// Domains indicate the dns query domains to use with this nameserver group
Domains []string
// Enabled group status // Enabled group status
Enabled bool Enabled bool
} }
@@ -128,6 +132,8 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
NameServers: g.NameServers, NameServers: g.NameServers,
Groups: g.Groups, Groups: g.Groups,
Enabled: g.Enabled, Enabled: g.Enabled,
Primary: g.Primary,
Domains: g.Domains,
} }
} }
@@ -136,8 +142,10 @@ func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
return other.ID == g.ID && return other.ID == g.ID &&
other.Name == g.Name && other.Name == g.Name &&
other.Description == g.Description && other.Description == g.Description &&
other.Primary == g.Primary &&
compareNameServerList(g.NameServers, other.NameServers) && compareNameServerList(g.NameServers, other.NameServers) &&
compareGroupsList(g.Groups, other.Groups) compareGroupsList(g.Groups, other.Groups) &&
compareGroupsList(g.Domains, other.Domains)
} }
func compareNameServerList(list, other []NameServer) bool { func compareNameServerList(list, other []NameServer) bool {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 887 KiB

26
go.mod
View File

@@ -18,12 +18,12 @@ require (
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1 golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.43.0 google.golang.org/grpc v1.43.0
google.golang.org/protobuf v1.28.0 google.golang.org/protobuf v1.28.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@@ -35,13 +35,20 @@ require (
github.com/eko/gocache/v3 v3.1.1 github.com/eko/gocache/v3 v3.1.1
github.com/getlantern/systray v1.2.1 github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4 github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/hashicorp/go-version v1.6.0
github.com/libp2p/go-netroute v0.2.0 github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5 github.com/magiconair/properties v1.8.5
github.com/miekg/dns v1.1.41
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.13.0
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0
go.opentelemetry.io/otel/sdk/metric v0.33.0
golang.org/x/net v0.0.0-20220630215102-69896b714898 golang.org/x/net v0.0.0-20220630215102-69896b714898
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
) )
@@ -65,11 +72,12 @@ require (
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect github.com/go-stack/stack v1.8.0 // indirect
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/google/go-cmp v0.5.7 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/google/gopacket v1.1.19 // indirect github.com/google/gopacket v1.1.19 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
@@ -89,20 +97,22 @@ require (
github.com/pion/turn/v2 v2.0.8 // indirect github.com/pion/turn/v2 v2.0.8 // indirect
github.com/pion/udp v0.1.1 // indirect github.com/pion/udp v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.12.2 // indirect
github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.33.0 // indirect github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.7.3 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
github.com/yuin/goldmark v1.4.1 // indirect github.com/yuin/goldmark v1.4.1 // indirect
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect
golang.org/x/tools v0.1.10 // indirect golang.org/x/tools v0.1.10 // indirect
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect

49
go.sum
View File

@@ -202,6 +202,11 @@ github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KE
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU=
github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0= github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -218,8 +223,9 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@@ -278,8 +284,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -342,6 +348,8 @@ github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerX
github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4=
github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
@@ -450,6 +458,7 @@ github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMV
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g= github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
@@ -542,8 +551,8 @@ github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3O
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_golang v1.12.2 h1:51L9cDoUHVrXx4zWYlcLQIZ+d+VXHgqnYKkIuq4g/34= github.com/prometheus/client_golang v1.13.0 h1:b71QUfeo5M8gq2+evJdTPfZhYMAU0uKPkyPJ7TPsloU=
github.com/prometheus/client_golang v1.12.2/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -554,15 +563,16 @@ github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8b
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls=
github.com/prometheus/common v0.33.0 h1:rHgav/0a6+uYgGdNt3jwz8FNSesO/Hsang3O0T9A5SE= github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE=
github.com/prometheus/common v0.33.0/go.mod h1:gB3sOl7P0TvJabZpLY5uQMpUqRCPPCyRLCZYc7JZTNE= github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
@@ -648,6 +658,18 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0/go.mod h1:ZSmYfKdYWEdSDBB4njLBIwTf4AU2JNsH3n2quVQDebI=
go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E=
go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI=
go.opentelemetry.io/otel/sdk v1.11.1 h1:F7KmQgoHljhUuJyA+9BiU+EkJfyX5nVVF4wyzWZpKxs=
go.opentelemetry.io/otel/sdk v1.11.1/go.mod h1:/l3FE4SupHJ12TduVjUkZtlfFqDCQJlOlithYrdktys=
go.opentelemetry.io/otel/sdk/metric v0.33.0 h1:oTqyWfksgKoJmbrs2q7O7ahkJzt+Ipekihf8vhpa9qo=
go.opentelemetry.io/otel/sdk/metric v0.33.0/go.mod h1:xdypMeA21JBOvjjzDUtD0kzIcHO/SPez+a8HOzJPGp0=
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
@@ -809,8 +831,9 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f h1:Ax0t5p6N38Ga0dThY21weqDEyz2oklo4IvDkpigvkD8=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -915,8 +938,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU= golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM= golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
@@ -1162,8 +1185,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -73,13 +73,15 @@ func parseAddress(address string) (WGAddress, error) {
func (w *WGIface) Close() error { func (w *WGIface) Close() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.Interface == nil {
return nil
}
err := w.Interface.Close() err := w.Interface.Close()
if err != nil { if err != nil {
return err return err
} }
if runtime.GOOS == "darwin" { if runtime.GOOS != "windows" {
sockPath := "/var/run/wireguard/" + w.Name + ".sock" sockPath := "/var/run/wireguard/" + w.Name + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil { if _, statErr := os.Stat(sockPath); statErr == nil {
statErr = os.Remove(sockPath) statErr = os.Remove(sockPath)

View File

@@ -89,7 +89,6 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
return addrs, nil return addrs, nil
} }
//
func Test_CreateInterface(t *testing.T) { func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32" wgIP := "10.99.99.1/32"
@@ -369,8 +368,8 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// todo: investigate why in some tests execution we need 30s
timeout := 10 * time.Second timeout := 30 * time.Second
timeoutChannel := time.After(timeout) timeoutChannel := time.After(timeout)
for { for {
select { select {

View File

@@ -75,3 +75,8 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.Address = addr w.Address = addr
return w.assignAddr() return w.assignAddr()
} }
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return "", nil
}

View File

@@ -58,6 +58,20 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
return w.assignAddr(luid) return w.assignAddr(luid)
} }
// GetInterfaceGUIDString returns an interface GUID string
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
if w.Interface == nil {
return "", fmt.Errorf("interface has not been initialized yet")
}
windowsDevice := w.Interface.(*driver.Adapter)
luid := windowsDevice.LUID()
guid, err := luid.GUID()
if err != nil {
return "", err
}
return guid.String(), nil
}
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) // WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool { func WireguardModuleIsLoaded() bool {
return false return false

View File

@@ -214,8 +214,9 @@ func isBuiltinModule(name string) (bool, error) {
} }
// /proc/modules // /proc/modules
// name | memory size | reference count | references | state: <Live|Loading|Unloading> //
// macvlan 28672 1 macvtap, Live 0x0000000000000000 // name | memory size | reference count | references | state: <Live|Loading|Unloading>
// macvlan 28672 1 macvtap, Live 0x0000000000000000
func moduleStatus(name string) (status, error) { func moduleStatus(name string) (status, error) {
state := unknown state := unknown
f, err := os.Open("/proc/modules") f, err := os.Open("/proc/modules")

View File

@@ -10,6 +10,8 @@ NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem" NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem"
# Management Certficate key file path. # Management Certficate key file path.
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem" NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem"
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
# Turn credentials # Turn credentials

View File

@@ -48,7 +48,7 @@ services:
# # port and command for Let's Encrypt validation without dashboard container # # port and command for Let's Encrypt validation without dashboard container
# - 443:443 # - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS"] command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN"]
# Coturn # Coturn
coturn: coturn:
image: coturn/coturn image: coturn/coturn

View File

@@ -49,18 +49,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, err := mgmt.NewStore(config.Datadir) store, err := mgmt.NewFileStore(config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peersUpdateManager := mgmt.NewPeersUpdateManager() peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil) accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -0,0 +1,18 @@
package cmd
const (
defaultMgmtDataDir = "/var/lib/netbird/"
defaultMgmtConfigDir = "/etc/netbird"
defaultLogDir = "/var/log/netbird"
oldDefaultMgmtDataDir = "/var/lib/wiretrustee/"
oldDefaultMgmtConfigDir = "/etc/wiretrustee"
oldDefaultLogDir = "/var/log/wiretrustee"
defaultMgmtConfig = defaultMgmtConfigDir + "/management.json"
defaultLogFile = defaultLogDir + "/management.log"
oldDefaultMgmtConfig = oldDefaultMgmtConfigDir + "/management.json"
oldDefaultLogFile = oldDefaultLogDir + "/management.log"
defaultSingleAccModeDomain = "netbird.selfhosted"
)

View File

@@ -8,8 +8,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/miekg/dns"
httpapi "github.com/netbirdio/netbird/management/server/http" httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
@@ -41,11 +43,13 @@ import (
const ManagementLegacyPort = 33073 const ManagementLegacyPort = 33073
var ( var (
mgmtPort int mgmtPort int
mgmtLetsencryptDomain string mgmtMetricsPort int
certFile string mgmtLetsencryptDomain string
certKey string mgmtSingleAccModeDomain string
config *server.Config certFile string
certKey string
config *server.Config
kaep = keepalive.EnforcementPolicy{ kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second, MinTime: 15 * time.Second,
@@ -86,6 +90,11 @@ var (
} }
} }
_, valid := dns.IsDomainName(dnsDomain)
if !valid || len(dnsDomain) > 192 {
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Lenght: %d", valid, len(dnsDomain))
}
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
@@ -107,21 +116,33 @@ var (
} }
} }
store, err := server.NewStore(config.Datadir) store, err := server.NewFileStore(config.Datadir)
if err != nil { if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
} }
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
appMetrics, err := telemetry.NewDefaultAppMetrics(cmd.Context())
if err != nil {
return err
}
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
var idpManager idp.Manager var idpManager idp.Manager
if config.IdpManagerConfig != nil { if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(*config.IdpManagerConfig) idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err) return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
} }
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager) if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, dnsDomain)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }
@@ -151,14 +172,14 @@ var (
tlsEnabled = true tlsEnabled = true
} }
httpAPIHandler, err := httpapi.APIHandler(accountManager, httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation) config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }
@@ -170,8 +191,6 @@ var (
return err return err
} }
fmt.Println("metrics ", disableMetrics)
if !disableMetrics { if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -225,11 +244,13 @@ var (
SetupCloseHandler() SetupCloseHandler()
<-stopCh <-stopCh
_ = appMetrics.Close()
_ = listener.Close() _ = listener.Close()
if certManager != nil { if certManager != nil {
_ = certManager.Listener().Close() _ = certManager.Listener().Close()
} }
gRPCAPIHandler.Stop() gRPCAPIHandler.Stop()
_ = store.Close()
log.Infof("stopped Management Service") log.Infof("stopped Management Service")
return nil return nil
@@ -427,7 +448,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
return nil, err return nil, err
} }
// Create the credentials and return it // NewDefaultAppMetrics the credentials and return it
config := &tls.Config{ config := &tls.Config{
Certificates: []tls.Certificate{serverCert}, Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert, ClientAuth: tls.NoClientCert,

View File

@@ -13,21 +13,13 @@ const (
) )
var ( var (
defaultMgmtConfigDir string dnsDomain string
defaultMgmtDataDir string mgmtDataDir string
defaultMgmtConfig string mgmtConfig string
defaultLogDir string logLevel string
defaultLogFile string logFile string
oldDefaultMgmtConfigDir string disableMetrics bool
oldDefaultMgmtDataDir string disableSingleAccMode bool
oldDefaultMgmtConfig string
oldDefaultLogDir string
oldDefaultLogFile string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@@ -46,28 +38,17 @@ func Execute() error {
func init() { func init() {
stopCh = make(chan int) stopCh = make(chan int)
defaultMgmtDataDir = "/var/lib/netbird/"
defaultMgmtConfigDir = "/etc/netbird"
defaultLogDir = "/var/log/netbird"
oldDefaultMgmtDataDir = "/var/lib/wiretrustee/"
oldDefaultMgmtConfigDir = "/etc/wiretrustee"
oldDefaultLogDir = "/var/log/wiretrustee"
defaultMgmtConfig = defaultMgmtConfigDir + "/management.json"
defaultLogFile = defaultLogDir + "/management.log"
oldDefaultMgmtConfig = oldDefaultMgmtConfigDir + "/management.json"
oldDefaultLogFile = oldDefaultLogDir + "/management.log"
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -1,15 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.12.4 // protoc v3.21.9
// source: management.proto // source: management.proto
package proto package proto
import ( import (
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect" reflect "reflect"
sync "sync" sync "sync"
) )
@@ -611,7 +611,7 @@ type ServerKeyResponse struct {
// Server's Wireguard public key // Server's Wireguard public key
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
// Key expiration timestamp after which the key should be fetched again by the client // Key expiration timestamp after which the key should be fetched again by the client
ExpiresAt *timestamp.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"` ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"`
// Version of the Wiretrustee Management Service protocol // Version of the Wiretrustee Management Service protocol
Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"` Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"`
} }
@@ -655,7 +655,7 @@ func (x *ServerKeyResponse) GetKey() string {
return "" return ""
} }
func (x *ServerKeyResponse) GetExpiresAt() *timestamp.Timestamp { func (x *ServerKeyResponse) GetExpiresAt() *timestamppb.Timestamp {
if x != nil { if x != nil {
return x.ExpiresAt return x.ExpiresAt
} }
@@ -909,6 +909,8 @@ type PeerConfig struct {
Dns string `protobuf:"bytes,2,opt,name=dns,proto3" json:"dns,omitempty"` Dns string `protobuf:"bytes,2,opt,name=dns,proto3" json:"dns,omitempty"`
// SSHConfig of the peer. // SSHConfig of the peer.
SshConfig *SSHConfig `protobuf:"bytes,3,opt,name=sshConfig,proto3" json:"sshConfig,omitempty"` SshConfig *SSHConfig `protobuf:"bytes,3,opt,name=sshConfig,proto3" json:"sshConfig,omitempty"`
// Peer fully qualified domain name
Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
} }
func (x *PeerConfig) Reset() { func (x *PeerConfig) Reset() {
@@ -964,6 +966,13 @@ func (x *PeerConfig) GetSshConfig() *SSHConfig {
return nil return nil
} }
func (x *PeerConfig) GetFqdn() string {
if x != nil {
return x.Fqdn
}
return ""
}
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
type NetworkMap struct { type NetworkMap struct {
state protoimpl.MessageState state protoimpl.MessageState
@@ -982,6 +991,8 @@ type NetworkMap struct {
RemotePeersIsEmpty bool `protobuf:"varint,4,opt,name=remotePeersIsEmpty,proto3" json:"remotePeersIsEmpty,omitempty"` RemotePeersIsEmpty bool `protobuf:"varint,4,opt,name=remotePeersIsEmpty,proto3" json:"remotePeersIsEmpty,omitempty"`
// List of routes to be applied // List of routes to be applied
Routes []*Route `protobuf:"bytes,5,rep,name=Routes,proto3" json:"Routes,omitempty"` Routes []*Route `protobuf:"bytes,5,rep,name=Routes,proto3" json:"Routes,omitempty"`
// DNS config to be applied
DNSConfig *DNSConfig `protobuf:"bytes,6,opt,name=DNSConfig,proto3" json:"DNSConfig,omitempty"`
} }
func (x *NetworkMap) Reset() { func (x *NetworkMap) Reset() {
@@ -1051,6 +1062,13 @@ func (x *NetworkMap) GetRoutes() []*Route {
return nil return nil
} }
func (x *NetworkMap) GetDNSConfig() *DNSConfig {
if x != nil {
return x.DNSConfig
}
return nil
}
// RemotePeerConfig represents a configuration of a remote peer. // RemotePeerConfig represents a configuration of a remote peer.
// The properties are used to configure Wireguard Peers sections // The properties are used to configure Wireguard Peers sections
type RemotePeerConfig struct { type RemotePeerConfig struct {
@@ -1064,6 +1082,8 @@ type RemotePeerConfig struct {
AllowedIps []string `protobuf:"bytes,2,rep,name=allowedIps,proto3" json:"allowedIps,omitempty"` AllowedIps []string `protobuf:"bytes,2,rep,name=allowedIps,proto3" json:"allowedIps,omitempty"`
// SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key. // SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key.
SshConfig *SSHConfig `protobuf:"bytes,3,opt,name=sshConfig,proto3" json:"sshConfig,omitempty"` SshConfig *SSHConfig `protobuf:"bytes,3,opt,name=sshConfig,proto3" json:"sshConfig,omitempty"`
// Peer fully qualified domain name
Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
} }
func (x *RemotePeerConfig) Reset() { func (x *RemotePeerConfig) Reset() {
@@ -1119,6 +1139,13 @@ func (x *RemotePeerConfig) GetSshConfig() *SSHConfig {
return nil return nil
} }
func (x *RemotePeerConfig) GetFqdn() string {
if x != nil {
return x.Fqdn
}
return ""
}
// SSHConfig represents SSH configurations of a peer. // SSHConfig represents SSH configurations of a peer.
type SSHConfig struct { type SSHConfig struct {
state protoimpl.MessageState state protoimpl.MessageState
@@ -1467,6 +1494,334 @@ func (x *Route) GetNetID() string {
return "" return ""
} }
// DNSConfig represents a dns.Update
type DNSConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"`
NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"`
CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"`
}
func (x *DNSConfig) Reset() {
*x = DNSConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[20]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DNSConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DNSConfig) ProtoMessage() {}
func (x *DNSConfig) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[20]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead.
func (*DNSConfig) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{20}
}
func (x *DNSConfig) GetServiceEnable() bool {
if x != nil {
return x.ServiceEnable
}
return false
}
func (x *DNSConfig) GetNameServerGroups() []*NameServerGroup {
if x != nil {
return x.NameServerGroups
}
return nil
}
func (x *DNSConfig) GetCustomZones() []*CustomZone {
if x != nil {
return x.CustomZones
}
return nil
}
// CustomZone represents a dns.CustomZone
type CustomZone struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"`
Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"`
}
func (x *CustomZone) Reset() {
*x = CustomZone{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[21]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CustomZone) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CustomZone) ProtoMessage() {}
func (x *CustomZone) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[21]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CustomZone.ProtoReflect.Descriptor instead.
func (*CustomZone) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{21}
}
func (x *CustomZone) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *CustomZone) GetRecords() []*SimpleRecord {
if x != nil {
return x.Records
}
return nil
}
// SimpleRecord represents a dns.SimpleRecord
type SimpleRecord struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
Type int64 `protobuf:"varint,2,opt,name=Type,proto3" json:"Type,omitempty"`
Class string `protobuf:"bytes,3,opt,name=Class,proto3" json:"Class,omitempty"`
TTL int64 `protobuf:"varint,4,opt,name=TTL,proto3" json:"TTL,omitempty"`
RData string `protobuf:"bytes,5,opt,name=RData,proto3" json:"RData,omitempty"`
}
func (x *SimpleRecord) Reset() {
*x = SimpleRecord{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[22]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SimpleRecord) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SimpleRecord) ProtoMessage() {}
func (x *SimpleRecord) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[22]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead.
func (*SimpleRecord) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{22}
}
func (x *SimpleRecord) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *SimpleRecord) GetType() int64 {
if x != nil {
return x.Type
}
return 0
}
func (x *SimpleRecord) GetClass() string {
if x != nil {
return x.Class
}
return ""
}
func (x *SimpleRecord) GetTTL() int64 {
if x != nil {
return x.TTL
}
return 0
}
func (x *SimpleRecord) GetRData() string {
if x != nil {
return x.RData
}
return ""
}
// NameServerGroup represents a dns.NameServerGroup
type NameServerGroup struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
NameServers []*NameServer `protobuf:"bytes,1,rep,name=NameServers,proto3" json:"NameServers,omitempty"`
Primary bool `protobuf:"varint,2,opt,name=Primary,proto3" json:"Primary,omitempty"`
Domains []string `protobuf:"bytes,3,rep,name=Domains,proto3" json:"Domains,omitempty"`
}
func (x *NameServerGroup) Reset() {
*x = NameServerGroup{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[23]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *NameServerGroup) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*NameServerGroup) ProtoMessage() {}
func (x *NameServerGroup) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[23]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead.
func (*NameServerGroup) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{23}
}
func (x *NameServerGroup) GetNameServers() []*NameServer {
if x != nil {
return x.NameServers
}
return nil
}
func (x *NameServerGroup) GetPrimary() bool {
if x != nil {
return x.Primary
}
return false
}
func (x *NameServerGroup) GetDomains() []string {
if x != nil {
return x.Domains
}
return nil
}
// NameServer represents a dns.NameServer
type NameServer struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
NSType int64 `protobuf:"varint,2,opt,name=NSType,proto3" json:"NSType,omitempty"`
Port int64 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"`
}
func (x *NameServer) Reset() {
*x = NameServer{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *NameServer) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*NameServer) ProtoMessage() {}
func (x *NameServer) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[24]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use NameServer.ProtoReflect.Descriptor instead.
func (*NameServer) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{24}
}
func (x *NameServer) GetIP() string {
if x != nil {
return x.IP
}
return ""
}
func (x *NameServer) GetNSType() int64 {
if x != nil {
return x.NSType
}
return 0
}
func (x *NameServer) GetPort() int64 {
if x != nil {
return x.Port
}
return 0
}
var File_management_proto protoreflect.FileDescriptor var File_management_proto protoreflect.FileDescriptor
var file_management_proto_rawDesc = []byte{ var file_management_proto_rawDesc = []byte{
@@ -1576,82 +1931,125 @@ var file_management_proto_rawDesc = []byte{
0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01,
0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73,
0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73,
0x77, 0x6f, 0x72, 0x64, 0x22, 0x6d, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x81, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e,
0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01,
0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a,
0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12,
0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01,
0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f,
0x66, 0x69, 0x67, 0x22, 0xf7, 0x01, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01,
0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xac, 0x02, 0x0a, 0x0a, 0x4e, 0x65, 0x74,
0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61,
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65,
0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65,
0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f,
0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74,
0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20,
0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73,
0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x83, 0x01, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74,
0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18,
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e,
0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f,
0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08,
0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08,
0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f,
0x66, 0x69, 0x67, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c,
0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43,
0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61,
0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66,
0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a,
0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64,
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e,
0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01,
0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c,
0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e,
0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74,
0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf,
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69,
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f,
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76,
0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64,
0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64,
0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76,
0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e,
0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18,
0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12,
0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18,
0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63,
0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20,
0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41,
0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41,
0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20,
0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45,
0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e,
0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d,
0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, 0xb5, 0x01,
0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20,
0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f,
0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65,
0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54,
0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28,
0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69,
0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12,
0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20,
0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12,
0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e,
0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d,
0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20,
0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70,
0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75,
0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65,
0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52,
0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a,
0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61,
0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20,
0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52,
0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65,
0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79,
0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14,
0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43,
0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28,
0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18,
0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0x7f, 0x0a, 0x0f,
0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12,
0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01,
0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61,
0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69,
0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d,
0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03,
0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22, 0x48, 0x0a,
0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49,
0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e,
0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54,
0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28,
0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a,
0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
@@ -1692,7 +2090,7 @@ func file_management_proto_rawDescGZIP() []byte {
} }
var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 2) var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 20) var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 25)
var file_management_proto_goTypes = []interface{}{ var file_management_proto_goTypes = []interface{}{
(HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol (HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol
(DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider (DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider
@@ -1716,7 +2114,12 @@ var file_management_proto_goTypes = []interface{}{
(*DeviceAuthorizationFlow)(nil), // 19: management.DeviceAuthorizationFlow (*DeviceAuthorizationFlow)(nil), // 19: management.DeviceAuthorizationFlow
(*ProviderConfig)(nil), // 20: management.ProviderConfig (*ProviderConfig)(nil), // 20: management.ProviderConfig
(*Route)(nil), // 21: management.Route (*Route)(nil), // 21: management.Route
(*timestamp.Timestamp)(nil), // 22: google.protobuf.Timestamp (*DNSConfig)(nil), // 22: management.DNSConfig
(*CustomZone)(nil), // 23: management.CustomZone
(*SimpleRecord)(nil), // 24: management.SimpleRecord
(*NameServerGroup)(nil), // 25: management.NameServerGroup
(*NameServer)(nil), // 26: management.NameServer
(*timestamppb.Timestamp)(nil), // 27: google.protobuf.Timestamp
} }
var file_management_proto_depIdxs = []int32{ var file_management_proto_depIdxs = []int32{
11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig 11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
@@ -1727,7 +2130,7 @@ var file_management_proto_depIdxs = []int32{
6, // 5: management.LoginRequest.peerKeys:type_name -> management.PeerKeys 6, // 5: management.LoginRequest.peerKeys:type_name -> management.PeerKeys
11, // 6: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig 11, // 6: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
14, // 7: management.LoginResponse.peerConfig:type_name -> management.PeerConfig 14, // 7: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
22, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 27, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
12, // 9: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig 12, // 9: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig
13, // 10: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig 13, // 10: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
12, // 11: management.WiretrusteeConfig.signal:type_name -> management.HostConfig 12, // 11: management.WiretrusteeConfig.signal:type_name -> management.HostConfig
@@ -1737,24 +2140,29 @@ var file_management_proto_depIdxs = []int32{
14, // 15: management.NetworkMap.peerConfig:type_name -> management.PeerConfig 14, // 15: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
16, // 16: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig 16, // 16: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
21, // 17: management.NetworkMap.Routes:type_name -> management.Route 21, // 17: management.NetworkMap.Routes:type_name -> management.Route
17, // 18: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig 22, // 18: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
1, // 19: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider 17, // 19: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
20, // 20: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig 1, // 20: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
2, // 21: management.ManagementService.Login:input_type -> management.EncryptedMessage 20, // 21: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
2, // 22: management.ManagementService.Sync:input_type -> management.EncryptedMessage 25, // 22: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
10, // 23: management.ManagementService.GetServerKey:input_type -> management.Empty 23, // 23: management.DNSConfig.CustomZones:type_name -> management.CustomZone
10, // 24: management.ManagementService.isHealthy:input_type -> management.Empty 24, // 24: management.CustomZone.Records:type_name -> management.SimpleRecord
2, // 25: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage 26, // 25: management.NameServerGroup.NameServers:type_name -> management.NameServer
2, // 26: management.ManagementService.Login:output_type -> management.EncryptedMessage 2, // 26: management.ManagementService.Login:input_type -> management.EncryptedMessage
2, // 27: management.ManagementService.Sync:output_type -> management.EncryptedMessage 2, // 27: management.ManagementService.Sync:input_type -> management.EncryptedMessage
9, // 28: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse 10, // 28: management.ManagementService.GetServerKey:input_type -> management.Empty
10, // 29: management.ManagementService.isHealthy:output_type -> management.Empty 10, // 29: management.ManagementService.isHealthy:input_type -> management.Empty
2, // 30: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage 2, // 30: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
26, // [26:31] is the sub-list for method output_type 2, // 31: management.ManagementService.Login:output_type -> management.EncryptedMessage
21, // [21:26] is the sub-list for method input_type 2, // 32: management.ManagementService.Sync:output_type -> management.EncryptedMessage
21, // [21:21] is the sub-list for extension type_name 9, // 33: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
21, // [21:21] is the sub-list for extension extendee 10, // 34: management.ManagementService.isHealthy:output_type -> management.Empty
0, // [0:21] is the sub-list for field type_name 2, // 35: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
31, // [31:36] is the sub-list for method output_type
26, // [26:31] is the sub-list for method input_type
26, // [26:26] is the sub-list for extension type_name
26, // [26:26] is the sub-list for extension extendee
0, // [0:26] is the sub-list for field type_name
} }
func init() { file_management_proto_init() } func init() { file_management_proto_init() }
@@ -2003,6 +2411,66 @@ func file_management_proto_init() {
return nil return nil
} }
} }
file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DNSConfig); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CustomZone); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SimpleRecord); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NameServerGroup); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NameServer); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@@ -2010,7 +2478,7 @@ func file_management_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_management_proto_rawDesc, RawDescriptor: file_management_proto_rawDesc,
NumEnums: 2, NumEnums: 2,
NumMessages: 20, NumMessages: 25,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -158,6 +158,8 @@ message PeerConfig {
// SSHConfig of the peer. // SSHConfig of the peer.
SSHConfig sshConfig = 3; SSHConfig sshConfig = 3;
// Peer fully qualified domain name
string fqdn = 4;
} }
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
@@ -178,6 +180,9 @@ message NetworkMap {
// List of routes to be applied // List of routes to be applied
repeated Route Routes = 5; repeated Route Routes = 5;
// DNS config to be applied
DNSConfig DNSConfig = 6;
} }
// RemotePeerConfig represents a configuration of a remote peer. // RemotePeerConfig represents a configuration of a remote peer.
@@ -193,6 +198,9 @@ message RemotePeerConfig {
// SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key. // SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key.
SSHConfig sshConfig = 3; SSHConfig sshConfig = 3;
// Peer fully qualified domain name
string fqdn = 4;
} }
// SSHConfig represents SSH configurations of a peer. // SSHConfig represents SSH configurations of a peer.
@@ -246,4 +254,40 @@ message Route {
int64 Metric = 5; int64 Metric = 5;
bool Masquerade = 6; bool Masquerade = 6;
string NetID = 7; string NetID = 7;
}
// DNSConfig represents a dns.Update
message DNSConfig {
bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3;
}
// CustomZone represents a dns.CustomZone
message CustomZone {
string Domain = 1;
repeated SimpleRecord Records = 2;
}
// SimpleRecord represents a dns.SimpleRecord
message SimpleRecord {
string Name = 1;
int64 Type = 2;
string Class = 3;
int64 TTL = 4;
string RData = 5;
}
// NameServerGroup represents a dns.NameServerGroup
message NameServerGroup {
repeated NameServer NameServers = 1;
bool Primary = 2;
repeated string Domains = 3;
}
// NameServer represents a dns.NameServer
message NameServer {
string IP = 1;
int64 NSType = 2;
int64 Port = 3;
} }

View File

@@ -8,13 +8,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"math/rand" "math/rand"
"net"
"net/netip"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
@@ -37,7 +38,6 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface { type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
CreateSetupKey( CreateSetupKey(
accountId string, accountId string,
keyName string, keyName string,
@@ -47,26 +47,25 @@ type AccountManager interface {
) (*SetupKey, error) ) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
CreateUser(accountID string, key *UserInfo) (*UserInfo, error) CreateUser(accountID string, key *UserInfo) (*UserInfo, error)
ListSetupKeys(accountID string) ([]*SetupKey, error) ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error) SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error) GetPeer(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error MarkPeerConnected(peerKey string, connected bool) error
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error) DeletePeer(accountId string, peerKey string) (*Peer, error)
GetPeerByIP(accountId string, peerIP string) (*Peer, error) GetPeerByIP(accountId string, peerIP string) (*Peer, error)
UpdatePeer(accountID string, peer *Peer) (*Peer, error) UpdatePeer(accountID string, peer *Peer) (*Peer, error)
GetNetworkMap(peerKey string) (*NetworkMap, error) GetNetworkMap(peerKey string) (*NetworkMap, error)
GetPeerNetwork(peerKey string) (*Network, error) GetPeerNetwork(peerKey string) (*Network, error)
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, error)
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
UpdatePeerSSHKey(peerKey string, sshKey string) error UpdatePeerSSHKey(peerKey string, sshKey string) error
GetUsersFromAccount(accountId string) ([]*UserInfo, error) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error) GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountId string, group *Group) error SaveGroup(accountId string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
@@ -75,29 +74,28 @@ type AccountManager interface {
GroupAddPeer(accountId, groupID, peerKey string) error GroupAddPeer(accountId, groupID, peerKey string) error
GroupDeletePeer(accountId, groupID, peerKey string) error GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error) GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetRule(accountId, ruleID string) (*Rule, error) GetRule(accountID, ruleID, userID string) (*Rule, error)
SaveRule(accountID string, rule *Rule) error SaveRule(accountID string, rule *Rule) error
UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error) UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error)
DeleteRule(accountId, ruleID string) error DeleteRule(accountId, ruleID string) error
ListRules(accountId string) ([]*Rule, error) ListRules(accountID, userID string) ([]*Rule, error)
GetRoute(accountID, routeID string) (*route.Route, error) GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
SaveRoute(accountID string, route *route.Route) error SaveRoute(accountID string, route *route.Route) error
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error) UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error DeleteRoute(accountID, routeID string) error
ListRoutes(accountID string) ([]*route.Route, error) ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID string) error DeleteNameServerGroup(accountID, nsGroupID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {
Store Store Store Store
// mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
@@ -106,6 +104,15 @@ type DefaultAccountManager struct {
idpManager idp.Manager idpManager idp.Manager
cacheManager cache.CacheInterface[[]*idp.UserData] cacheManager cache.CacheInterface[[]*idp.UserData]
ctx context.Context ctx context.Context
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
} }
// Account represents a unique account of the system // Account represents a unique account of the system
@@ -135,6 +142,195 @@ type UserInfo struct {
Status string `json:"-"` Status string `json:"-"`
} }
// GetPeersRoutes returns all active routes of provided peers
func (a *Account) GetPeersRoutes(givenPeers []*Peer) []*route.Route {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
routes := make([]*route.Route, 0)
for _, peer := range givenPeers {
peerRoutes := a.GetPeerRoutes(peer.Key)
activeRoutes := make([]*route.Route, 0)
for _, pr := range peerRoutes {
if pr.Enabled {
activeRoutes = append(activeRoutes, pr)
}
}
if len(activeRoutes) > 0 {
routes = append(routes, activeRoutes...)
}
}
return routes
}
// GetPeerRoutes returns a list of routes of a given peer
func (a *Account) GetPeerRoutes(peerPubKey string) []*route.Route {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
var routes []*route.Route
for _, r := range a.Routes {
if r.Peer == peerPubKey {
routes = append(routes, r)
continue
}
}
return routes
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
var routes []*route.Route
for _, r := range a.Routes {
if r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes
}
// GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerPubKey string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{})
for s, group := range a.Groups {
for _, peer := range group.Peers {
if peerPubKey == peer {
peerGroups[s] = struct{}{}
break
}
}
}
// Second, find all rules that have discovered source and destination groups
srcRulesMap := make(map[string]*Rule)
dstRulesMap := make(map[string]*Rule)
for _, rule := range a.Rules {
for _, g := range rule.Source {
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
srcRules = append(srcRules, rule)
srcRulesMap[rule.ID] = rule
}
}
for _, g := range rule.Destination {
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
dstRules = append(dstRules, rule)
dstRulesMap[rule.ID] = rule
}
}
}
return srcRules, dstRules
}
// GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*Peer {
var peers []*Peer
for _, peer := range a.Peers {
peers = append(peers, peer)
}
return peers
}
// UpdatePeer saves new or replaces existing peer
func (a *Account) UpdatePeer(update *Peer) {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
a.Peers[update.Key] = update
}
// DeletePeer deletes peer from the account cleaning up all the references
func (a *Account) DeletePeer(peerPubKey string) {
// TODO Peer.ID migration: we will need to replace search by Peer.ID here
// delete peer from groups
for _, g := range a.Groups {
for i, pk := range g.Peers {
if pk == peerPubKey {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
break
}
}
}
for _, r := range a.Routes {
if r.Peer == peerPubKey {
r.Enabled = false
r.Peer = ""
}
}
delete(a.Peers, peerPubKey)
a.Network.IncSerial()
}
// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found.
// It will return an object copy of the peer.
func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
for _, peer := range a.Peers {
if peer.Key == peerPubKey {
return peer.Copy(), nil
}
}
return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
}
// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, status.Errorf(status.NotFound, "user %s not found", userID)
}
return user, nil
}
// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found.
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
key := a.SetupKeys[setupKey]
if key == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return key, nil
}
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps
}
func (a *Account) getPeerDNSLabels() lookupMap {
existingLabels := make(lookupMap)
for _, peer := range a.Peers {
if peer.DNSLabel != "" {
existingLabels[peer.DNSLabel] = struct{}{}
}
}
return existingLabels
}
func (a *Account) Copy() *Account { func (a *Account) Copy() *Account {
peers := map[string]*Peer{} peers := map[string]*Peer{}
for id, peer := range a.Peers { for id, peer := range a.Peers {
@@ -172,16 +368,19 @@ func (a *Account) Copy() *Account {
} }
return &Account{ return &Account{
Id: a.Id, Id: a.Id,
CreatedBy: a.CreatedBy, CreatedBy: a.CreatedBy,
SetupKeys: setupKeys, Domain: a.Domain,
Network: a.Network.Copy(), DomainCategory: a.DomainCategory,
Peers: peers, IsDomainPrimaryAccount: a.IsDomainPrimaryAccount,
Users: users, SetupKeys: setupKeys,
Groups: groups, Network: a.Network.Copy(),
Rules: rules, Peers: peers,
Routes: routes, Users: users,
NameServerGroups: nsGroups, Groups: groups,
Rules: rules,
Routes: routes,
NameServerGroups: nsGroups,
} }
} }
@@ -195,36 +394,51 @@ func (a *Account) GetGroupAll() (*Group, error) {
} }
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager( func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string) (*DefaultAccountManager, error) {
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
mux: sync.Mutex{},
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
cacheMux: sync.Mutex{}, cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
}
allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
if am.singleAccountMode {
am.singleAccountModeDomain = singleAccountModeDomain
log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
} else {
log.Infof("single account mode disabled, accounts number %d", len(allAccounts))
} }
// if account has not default group // if account doesn't have a default group
// we create 'all' group and add all peers into it // we create 'all' group and add all peers into it
// also we create default rule with source as destination // also we create default rule with source as destination
for _, account := range store.GetAllAccounts() { for _, account := range allAccounts {
shouldSave := false
_, err := account.GetGroupAll() _, err := account.GetGroupAll()
if err != nil { if err != nil {
addAllGroup(account) addAllGroup(account)
if err := store.SaveAccount(account); err != nil { shouldSave = true
}
if shouldSave {
err = store.SaveAccount(account)
if err != nil {
return nil, err return nil, err
} }
} }
} }
gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
gocacheStore := cacheStore.NewGoCache(gocacheClient) goCacheStore := cacheStore.NewGoCache(goCacheClient)
am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](gocacheStore)) am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore))
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
go func() { go func() {
@@ -252,14 +466,14 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er
if err == nil { if err == nil {
log.Warnf("an account with ID already exists, retrying...") log.Warnf("an account with ID already exists, retrying...")
continue continue
} else if statusErr.Code() == codes.NotFound { } else if statusErr.Type() == status.NotFound {
return newAccountWithId(accountId, userID, domain), nil return newAccountWithId(accountId, userID, domain), nil
} else { } else {
return nil, err return nil, err
} }
} }
return nil, status.Errorf(codes.Internal, "error while creating new account") return nil, status.Errorf(status.Internal, "error while creating new account")
} }
func (am *DefaultAccountManager) warmupIDPCache() error { func (am *DefaultAccountManager) warmupIDPCache() error {
@@ -278,39 +492,24 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil return nil
} }
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) { // userID doesn't have an account associated with it, one account is created
am.mux.Lock() func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
defer am.mux.Unlock() if accountID != "" {
return am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(accountId) } else if userID != "" {
if err != nil { account, err := am.GetOrCreateAccountByUser(userID, domain)
return nil, status.Errorf(codes.NotFound, "account not found")
}
return account, nil
}
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*Account, error) {
if accountId != "" {
return am.GetAccountById(accountId)
} else if userId != "" {
account, err := am.GetOrCreateAccountByUser(userId, domain)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
} }
err = am.addAccountIDToIDPAppMeta(userId, account) err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return account, nil
} }
return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided") return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
} }
func isNil(i idp.Manager) bool { func isNil(i idp.Manager) bool {
@@ -340,11 +539,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
} }
if err != nil { if err != nil {
return status.Errorf( return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
codes.Internal,
"updating user's app metadata failed with: %v",
err,
)
} }
// refresh cache to reflect the update // refresh cache to reflect the update
_, err = am.refreshCache(account.Id) _, err = am.refreshCache(account.Id)
@@ -471,11 +666,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
} }
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account // updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes( func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
account *Account, primaryDomain bool) error {
claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
account.IsDomainPrimaryAccount = primaryDomain account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain) lowerDomain := strings.ToLower(claims.Domain)
@@ -490,7 +682,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
err := am.Store.SaveAccount(account) err := am.Store.SaveAccount(account)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "failed saving updated account") return err
} }
return nil return nil
} }
@@ -532,10 +724,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain. // otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount( func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
var ( var (
account *Account account *Account
err error err error
@@ -547,7 +736,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
account.Users[claims.UserId] = NewRegularUser(claims.UserId) account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed saving updated account") return nil, err
} }
} else { } else {
account, err = am.newAccount(claims.UserId, lowerDomain) account, err = am.newAccount(claims.UserId, lowerDomain)
@@ -582,10 +771,10 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
} }
if user == nil { if user == nil {
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID) return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
} }
if user.AppMetadata.WTPendingInvite { if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.Infof("redeeming invite for user %s account %s", userID, account.Id) log.Infof("redeeming invite for user %s account %s", userID, account.Id)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false. // User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache. // Our job is to just reload cache.
@@ -603,18 +792,33 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
} }
// GetAccountFromToken returns an account associated with this token // GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
claims.Domain = am.singleAccountModeDomain
claims.DomainCategory = PrivateCategory
log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
account, err := am.getAccountWithAuthorizationClaims(claims) account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
err = am.redeemInvite(account, claims.UserId) err = am.redeemInvite(account, claims.UserId)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return account, nil return account, user, nil
} }
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. // getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
@@ -634,15 +838,13 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes // Existing user + Existing account + Existing Indexed Domain -> Nothing changes
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims( func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId) accountFromID, err := am.Store.GetAccount(claims.AccountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -654,24 +856,27 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
} }
} }
am.mux.Lock() unlock := am.Store.AcquireGlobalLock()
defer am.mux.Unlock() defer unlock()
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
accStatus, _ := status.FromError(err) if err != nil {
if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound { // if NotFound we are good to continue, otherwise return error
return nil, err e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return nil, err
}
} }
account, err := am.Store.GetUserAccount(claims.UserId) account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil { if err == nil {
err = am.handleExistingUserAccount(account, domainAccount, claims) err = am.handleExistingUserAccount(account, domainAccount, claims)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return account, nil
} else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return am.handleNewUserAccount(domainAccount, claims) return am.handleNewUserAccount(domainAccount, claims)
} else { } else {
// other error // other error
@@ -685,14 +890,15 @@ func isDomainValid(domain string) bool {
} }
// AccountExists checks whether account exists (returns true) or not (returns false) // AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
var res bool var res bool
_, err := am.Store.GetAccount(accountId) _, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
res = false res = false
return &res, nil return &res, nil
} else { } else {
@@ -704,6 +910,11 @@ func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error)
return &res, nil return &res, nil
} }
// GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain() string {
return am.dnsDomain
}
// addAllGroup to account object if it doesn't exists // addAllGroup to account object if it doesn't exists
func addAllGroup(account *Account) { func addAllGroup(account *Account) {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {
@@ -761,15 +972,6 @@ func newAccountWithId(accountId, userId, domain string) *Account {
return acc return acc
} }
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
for _, k := range acc.SetupKeys {
if key == k.Key {
return k
}
}
return nil
}
func removeFromList(inputList []string, toRemove []string) []string { func removeFromList(inputList []string, toRemove []string) []string {
toRemoveMap := make(map[string]struct{}) toRemoveMap := make(map[string]struct{})
for _, item := range toRemove { for _, item := range toRemove {

View File

@@ -1,7 +1,11 @@
package server package server
import ( import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"net" "net"
"reflect"
"sync" "sync"
"testing" "testing"
@@ -117,7 +121,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.GetAccountByUser(userId) account, err = manager.Store.GetAccountByUser(userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@@ -298,7 +302,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs { if testCase.inputUpdateAttrs {
@@ -310,7 +314,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id testCase.inputClaims.AccountId = initAccount.Id
} }
account, err := manager.GetAccountFromToken(testCase.inputClaims) account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
require.NoError(t, err, "support function failed") require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -341,7 +345,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.GetAccountByUser(userId) account, err = manager.Store.GetAccountByUser(userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@@ -397,7 +401,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
account, err := manager.GetAccountByUserOrAccountId(userId, "", "") account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -407,12 +411,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
accountId := account.Id accountId := account.Id
_, err = manager.GetAccountByUserOrAccountId("", accountId, "") _, err = manager.GetAccountByUserOrAccountID("", accountId, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
} }
_, err = manager.GetAccountByUserOrAccountId("", "", "") _, err = manager.GetAccountByUserOrAccountID("", "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user and account IDs are empty")
} }
@@ -466,7 +470,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
} }
// AddAccount has been already tested so we can assume it is correct and compare results // AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccountById(expectedId) getAccount, err := manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -536,7 +540,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -598,7 +602,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -676,7 +680,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer() peer2 := getPeer()
peer3 := getPeer() peer3 := getPeer()
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -844,7 +848,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -874,7 +878,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user account.Users[user.Id] = user
} }
userInfos, err := manager.GetUsersFromAccount(accountId) userInfos, err := manager.GetUsersFromAccount(accountId, "1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -957,17 +961,257 @@ func TestAccountManager_UpdatePeerMeta(t *testing.T) {
assert.Equal(t, newMeta, p.Meta) assert.Equal(t, newMeta, p.Meta)
} }
func TestAccount_GetPeerRules(t *testing.T) {
groups := map[string]*Group{
"group_1": {
ID: "group_1",
Name: "group_1",
Peers: []string{"peer-1", "peer-2"},
},
"group_2": {
ID: "group_2",
Name: "group_2",
Peers: []string{"peer-2", "peer-3"},
},
"group_3": {
ID: "group_3",
Name: "group_3",
Peers: []string{"peer-4"},
},
"group_4": {
ID: "group_4",
Name: "group_4",
Peers: []string{"peer-1"},
},
"group_5": {
ID: "group_5",
Name: "group_5",
Peers: []string{"peer-1"},
},
}
rules := map[string]*Rule{
"rule-1": {
ID: "rule-1",
Name: "rule-1",
Description: "rule-1",
Disabled: false,
Source: []string{"group_1", "group_5"},
Destination: []string{"group_2"},
Flow: 0,
},
"rule-2": {
ID: "rule-2",
Name: "rule-2",
Description: "rule-2",
Disabled: false,
Source: []string{"group_1"},
Destination: []string{"group_1"},
Flow: 0,
},
"rule-3": {
ID: "rule-3",
Name: "rule-3",
Description: "rule-3",
Disabled: false,
Source: []string{"group_3"},
Destination: []string{"group_3"},
Flow: 0,
},
}
account := &Account{
Groups: groups,
Rules: rules,
}
srcRules, dstRules := account.GetPeerRules("peer-1")
assert.Equal(t, 2, len(srcRules))
assert.Equal(t, 1, len(dstRules))
}
func TestFileStore_GetRoutesByPrefix(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
account := &Account{
Routes: map[string]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
"route-2": {
ID: "route-2",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
},
}
routes := account.GetRoutesByPrefix(prefix)
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-1")
assert.Contains(t, routeIDs, "route-2")
}
func TestAccount_GetPeersRoutes(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
account := &Account{
Peers: map[string]*Peer{
"peer-1": {Key: "peer-1"}, "peer-2": {Key: "peer-2"}, "peer-3": {Key: "peer-1"},
},
Routes: map[string]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
"route-2": {
ID: "route-2",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
},
}
routes := account.GetPeersRoutes([]*Peer{{Key: "peer-1"}, {Key: "peer-2"}, {Key: "non-existing-peer"}})
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-1")
assert.Contains(t, routeIDs, "route-2")
}
func TestAccount_Copy(t *testing.T) {
account := &Account{
Id: "account1",
CreatedBy: "tester",
Domain: "test.com",
DomainCategory: "public",
IsDomainPrimaryAccount: true,
SetupKeys: map[string]*SetupKey{
"setup1": {
Id: "setup1",
AutoGroups: []string{"group1"},
},
},
Network: &Network{
Id: "net1",
},
Peers: map[string]*Peer{
"peer1": {
Key: "key1",
},
},
Users: map[string]*User{
"user1": {
Id: "user1",
Role: UserRoleAdmin,
AutoGroups: []string{"group1"},
},
},
Groups: map[string]*Group{
"group1": {
ID: "group1",
},
},
Rules: map[string]*Rule{
"rule1": {
ID: "rule1",
},
},
Routes: map[string]*route.Route{
"route1": {
ID: "route1",
},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": {
ID: "nsGroup1",
},
},
}
err := hasNilField(account)
if err != nil {
t.Fatal(err)
}
accountCopy := account.Copy()
assert.Equal(t, account, accountCopy, "account copy returned a different value than expected")
}
// hasNilField validates pointers, maps and slices if they are nil
func hasNilField(x interface{}) error {
rv := reflect.ValueOf(x)
rv = rv.Elem()
for i := 0; i < rv.NumField(); i++ {
if f := rv.Field(i); f.IsValid() {
k := f.Kind()
switch k {
case reflect.Ptr:
if f.IsNil() {
return fmt.Errorf("field %s is nil", f.String())
}
case reflect.Map, reflect.Slice:
if f.Len() == 0 || f.IsNil() {
return fmt.Errorf("field %s is nil", f.String())
}
}
}
}
return nil
}
func createManager(t *testing.T) (*DefaultAccountManager, error) { func createManager(t *testing.T) (*DefaultAccountManager, error) {
store, err := createStore(t) store, err := createStore(t)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return BuildManager(store, NewPeersUpdateManager(), nil) return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
} }
func createStore(t *testing.T) (Store, error) { func createStore(t *testing.T) (Store, error) {
dataDir := t.TempDir() dataDir := t.TempDir()
store, err := NewStore(dataDir) store, err := NewFileStore(dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }

152
management/server/dns.go Normal file
View File

@@ -0,0 +1,152 @@
package server
import (
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus"
"strconv"
)
type lookupMap map[string]struct{}
const defaultTTL = 300
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
for _, zone := range update.CustomZones {
protoZone := &proto.CustomZone{Domain: zone.Domain}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
}
for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
}
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
}
return customZone
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
groupList := make(lookupMap)
for groupID, group := range account.Groups {
for _, id := range group.Peers {
if id == peerID {
groupList[groupID] = struct{}{}
break
}
}
}
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range account.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
return peerNSGroups
}
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil {
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil {
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skiping", peer.Meta.Hostname, err)
continue
}
}
peer.DNSLabel = label
peerLabels[label] = struct{}{}
}
}
func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) {
label, err := nbdns.GetParsedDomainLabel(name)
if err != nil {
return "", err
}
uniqueLabel := getUniqueHostLabel(label, peerLabels)
if uniqueLabel == "" {
return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label)
}
return uniqueLabel, nil
}
// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999
func getUniqueHostLabel(name string, peerLabels lookupMap) string {
_, found := peerLabels[name]
if !found {
return name
}
for i := 1; i < 1000; i++ {
nameWithSuffix := name + "-" + strconv.Itoa(i)
_, found = peerLabels[nameWithSuffix]
if !found {
return nameWithSuffix
}
}
return ""
}

View File

@@ -1,52 +0,0 @@
package server
import (
"fmt"
)
const (
// UserAlreadyExists indicates that user already exists
UserAlreadyExists ErrorType = 1
// AccountNotFound indicates that specified account hasn't been found
AccountNotFound ErrorType = iota
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed ErrorType = iota
)
// ErrorType is a type of the Error
type ErrorType int32
// Error is an internal error
type Error struct {
errorType ErrorType
message string
}
// Type returns the Type of the error
func (e *Error) Type() ErrorType {
return e.errorType
}
// Error is an error string
func (e *Error) Error() string {
return e.message
}
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
return &Error{
errorType: errorType,
message: fmt.Sprintf(format, a...),
}
}
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
func FromError(err error) (s *Error, ok bool) {
if err == nil {
return nil, true
}
if e, ok := err.(*Error); ok {
return e, true
}
return nil, false
}

View File

@@ -1,16 +1,13 @@
package server package server
import ( import (
"fmt" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" log "github.com/sirupsen/logrus"
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -21,29 +18,29 @@ const storeFileName = "store.json"
// FileStore represents an account storage backed by a file persisted to disk // FileStore represents an account storage backed by a file persisted to disk
type FileStore struct { type FileStore struct {
Accounts map[string]*Account Accounts map[string]*Account
SetupKeyId2AccountId map[string]string `json:"-"` SetupKeyID2AccountID map[string]string `json:"-"`
PeerKeyId2AccountId map[string]string `json:"-"` PeerKeyID2AccountID map[string]string `json:"-"`
UserId2AccountId map[string]string `json:"-"` UserID2AccountID map[string]string `json:"-"`
PrivateDomain2AccountId map[string]string `json:"-"` PrivateDomain2AccountID map[string]string `json:"-"`
PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyID2RouteIDs map[string]map[string]struct{} `json:"-"`
AccountPrefix2RouteIDs map[string]map[string][]string `json:"-"`
InstallationID string InstallationID string
// mutex to synchronise Store read/write operations // mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"` mux sync.Mutex `json:"-"`
storeFile string `json:"-"` storeFile string `json:"-"`
// sync.Mutex indexed by accountID
accountLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
} }
type StoredAccount struct{} type StoredAccount struct{}
// NewStore restores a store from the file located in the datadir // NewFileStore restores a store from the file located in the datadir
func NewStore(dataDir string) (*FileStore, error) { func NewFileStore(dataDir string) (*FileStore, error) {
return restore(filepath.Join(dataDir, storeFileName)) return restore(filepath.Join(dataDir, storeFileName))
} }
// restore restores the state of the store from the file. // restore the state of the store from the file.
// Creates a new empty store file if doesn't exist // Creates a new empty store file if doesn't exist
func restore(file string) (*FileStore, error) { func restore(file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) { if _, err := os.Stat(file); os.IsNotExist(err) {
@@ -51,14 +48,11 @@ func restore(file string) (*FileStore, error) {
s := &FileStore{ s := &FileStore{
Accounts: make(map[string]*Account), Accounts: make(map[string]*Account),
mux: sync.Mutex{}, mux: sync.Mutex{},
SetupKeyId2AccountId: make(map[string]string), globalAccountLock: sync.Mutex{},
PeerKeyId2AccountId: make(map[string]string), SetupKeyID2AccountID: make(map[string]string),
UserId2AccountId: make(map[string]string), PeerKeyID2AccountID: make(map[string]string),
PrivateDomain2AccountId: make(map[string]string), UserID2AccountID: make(map[string]string),
PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}), PrivateDomain2AccountID: make(map[string]string),
PeerKeyID2RouteIDs: make(map[string]map[string]struct{}),
PeerKeyId2DstRulesId: make(map[string]map[string]struct{}),
AccountPrefix2RouteIDs: make(map[string]map[string][]string),
storeFile: file, storeFile: file,
} }
@@ -77,287 +71,113 @@ func restore(file string) (*FileStore, error) {
store := read.(*FileStore) store := read.(*FileStore)
store.storeFile = file store.storeFile = file
store.SetupKeyId2AccountId = make(map[string]string) store.SetupKeyID2AccountID = make(map[string]string)
store.PeerKeyId2AccountId = make(map[string]string) store.PeerKeyID2AccountID = make(map[string]string)
store.UserId2AccountId = make(map[string]string) store.UserID2AccountID = make(map[string]string)
store.PrivateDomain2AccountId = make(map[string]string) store.PrivateDomain2AccountID = make(map[string]string)
store.PeerKeyId2SrcRulesId = make(map[string]map[string]struct{})
store.PeerKeyId2DstRulesId = make(map[string]map[string]struct{})
store.PeerKeyID2RouteIDs = make(map[string]map[string]struct{})
store.AccountPrefix2RouteIDs = make(map[string]map[string][]string)
for accountId, account := range store.Accounts { for accountID, account := range store.Accounts {
for setupKeyId := range account.SetupKeys { for setupKeyId := range account.SetupKeys {
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId store.SetupKeyID2AccountID[strings.ToUpper(setupKeyId)] = accountID
}
for _, rule := range account.Rules {
for _, groupID := range rule.Source {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2SrcRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2SrcRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, groupID := range rule.Destination {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2DstRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2DstRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
store.PeerKeyId2AccountId[peer.Key] = accountId store.PeerKeyID2AccountID[peer.Key] = accountID
// reset all peers to status = Disconnected
if peer.Status != nil && peer.Status.Connected {
peer.Status.Connected = false
}
} }
for _, user := range account.Users { for _, user := range account.Users {
store.UserId2AccountId[user.Id] = accountId store.UserID2AccountID[user.Id] = accountID
} }
for _, user := range account.Users { for _, user := range account.Users {
store.UserId2AccountId[user.Id] = accountId store.UserID2AccountID[user.Id] = accountID
}
for _, route := range account.Routes {
if route.Peer == "" {
continue
}
if store.PeerKeyID2RouteIDs[route.Peer] == nil {
store.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
}
store.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
if store.AccountPrefix2RouteIDs[account.Id] == nil {
store.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
}
if _, ok := store.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
}
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
route.ID,
)
} }
if account.Domain != "" && account.DomainCategory == PrivateCategory && if account.Domain != "" && account.DomainCategory == PrivateCategory &&
account.IsDomainPrimaryAccount { account.IsDomainPrimaryAccount {
store.PrivateDomain2AccountId[account.Domain] = accountId store.PrivateDomain2AccountID[account.Domain] = accountID
} }
// for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels()
if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(account, existingLabels)
}
}
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
err = store.persist(store.storeFile)
if err != nil {
return nil, err
} }
return store, nil return store, nil
} }
// persist persists account data to a file // persist account data to a file
// It is recommended to call it with locking FileStore.mux // It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(file string) error { func (s *FileStore) persist(file string) error {
return util.WriteJson(file, s) return util.WriteJson(file, s)
} }
// SavePeer saves updated peer // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) SavePeer(accountId string, peer *Peer) error { func (s *FileStore) AcquireGlobalLock() (unlock func()) {
s.mux.Lock() log.Debugf("acquiring global lock")
defer s.mux.Unlock() start := time.Now()
s.globalAccountLock.Lock()
account, err := s.GetAccount(accountId) unlock = func() {
if err != nil { s.globalAccountLock.Unlock()
return err log.Debugf("released global lock in %v", time.Since(start))
} }
// if it is new peer, add it to default 'All' group return unlock
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
ind := -1
for i, pid := range allGroup.Peers {
if pid == peer.Key {
ind = i
break
}
}
if ind < 0 {
allGroup.Peers = append(allGroup.Peers, peer.Key)
}
account.Peers[peer.Key] = peer
return s.persist(s.storeFile)
} }
// DeletePeer deletes peer from the Store // AcquireAccountLock acquires account lock and returns a function that releases the lock
func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error) { func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
s.mux.Lock() log.Debugf("acquiring lock for account %s", accountID)
defer s.mux.Unlock() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
account, err := s.GetAccount(accountId) unlock = func() {
if err != nil { mtx.Unlock()
return nil, err log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
} }
peer := account.Peers[peerKey] return unlock
if peer == nil {
return nil, status.Errorf(codes.NotFound, "peer not found")
}
peerRoutes := s.PeerKeyID2RouteIDs[peerKey]
delete(account.Peers, peerKey)
delete(s.PeerKeyId2AccountId, peerKey)
delete(s.PeerKeyId2DstRulesId, peerKey)
delete(s.PeerKeyId2SrcRulesId, peerKey)
delete(s.PeerKeyID2RouteIDs, peerKey)
// cleanup groups
for _, g := range account.Groups {
var peers []string
for _, p := range g.Peers {
if p != peerKey {
peers = append(peers, p)
}
}
g.Peers = peers
}
for routeID := range peerRoutes {
account.Routes[routeID].Enabled = false
account.Routes[routeID].Peer = ""
}
err = s.persist(s.storeFile)
if err != nil {
return nil, err
}
return peer, nil
} }
// GetPeer returns a peer from a Store
func (s *FileStore) GetPeer(peerKey string) (*Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "peer not found")
}
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
if peer, ok := account.Peers[peerKey]; ok {
return peer, nil
}
return nil, status.Errorf(codes.NotFound, "peer not found")
}
// SaveAccount updates an existing account or adds a new one
func (s *FileStore) SaveAccount(account *Account) error { func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
accountCopy := account.Copy()
// todo will override, handle existing keys // todo will override, handle existing keys
s.Accounts[account.Id] = account s.Accounts[accountCopy.Id] = accountCopy
// todo check that account.Id and keyId are not exist already // todo check that account.Id and keyId are not exist already
// because if keyId exists for other accounts this can be bad // because if keyId exists for other accounts this can be bad
for keyId := range account.SetupKeys { for keyID := range accountCopy.SetupKeys {
s.SetupKeyId2AccountId[strings.ToUpper(keyId)] = account.Id s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id
} }
// enforce peer to account index and delete peer to route indexes for rebuild // enforce peer to account index and delete peer to route indexes for rebuild
for _, peer := range account.Peers { for _, peer := range accountCopy.Peers {
s.PeerKeyId2AccountId[peer.Key] = account.Id s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id
delete(s.PeerKeyID2RouteIDs, peer.Key)
} }
delete(s.AccountPrefix2RouteIDs, account.Id) for _, user := range accountCopy.Users {
s.UserID2AccountID[user.Id] = accountCopy.Id
// remove all peers related to account from rules indexes
cleanIDs := make([]string, 0)
for key := range s.PeerKeyId2SrcRulesId {
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
cleanIDs = append(cleanIDs, key)
}
}
for _, key := range cleanIDs {
delete(s.PeerKeyId2SrcRulesId, key)
}
cleanIDs = cleanIDs[:0]
for key := range s.PeerKeyId2DstRulesId {
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
cleanIDs = append(cleanIDs, key)
}
}
for _, key := range cleanIDs {
delete(s.PeerKeyId2DstRulesId, key)
} }
// rebuild rule indexes if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
for _, rule := range account.Rules { s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
for _, gid := range rule.Source {
g, ok := account.Groups[gid]
if !ok {
break
}
for _, pid := range g.Peers {
rules := s.PeerKeyId2SrcRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2SrcRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
for _, gid := range rule.Destination {
g, ok := account.Groups[gid]
if !ok {
break
}
for _, pid := range g.Peers {
rules := s.PeerKeyId2DstRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2DstRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, route := range account.Routes {
if route.Peer == "" {
continue
}
if s.PeerKeyID2RouteIDs[route.Peer] == nil {
s.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
}
s.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
if s.AccountPrefix2RouteIDs[account.Id] == nil {
s.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
}
if _, ok := s.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
}
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
route.ID,
)
}
for _, user := range account.Users {
s.UserId2AccountId[user.Id] = account.Id
}
if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
s.PrivateDomain2AccountId[account.Domain] = account.Id
} }
return s.persist(s.storeFile) return s.persist(s.storeFile)
@@ -365,53 +185,38 @@ func (s *FileStore) SaveAccount(account *Account) error {
// GetAccountByPrivateDomain returns account by private domain // GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)] s.mux.Lock()
if !accountIdFound { defer s.mux.Unlock()
return nil, status.Errorf(
codes.NotFound, accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
"provided domain is not registered or is not private", if !accountIDFound {
) return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
} }
account, err := s.GetAccount(accountId) account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return account.Copy(), nil
} }
// GetAccountBySetupKey returns account by setup key id // GetAccountBySetupKey returns account by setup key id
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists")
}
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
return account, nil
}
// GetAccountPeers returns account peers
func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
account, err := s.GetAccount(accountId) accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !accountIDFound {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}
account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var peers []*Peer return account.Copy(), nil
for _, peer := range account.Peers {
peers = append(peers, peer)
}
return peers, nil
} }
// GetAllAccounts returns all accounts // GetAllAccounts returns all accounts
@@ -425,149 +230,63 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
return all return all
} }
// GetAccount returns an account for id // getAccount returns a reference to the Account. Should not return a copy.
func (s *FileStore) GetAccount(accountId string) (*Account, error) { func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, accountFound := s.Accounts[accountId] account, accountFound := s.Accounts[accountID]
if !accountFound { if !accountFound {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(status.NotFound, "account not found")
} }
return account, nil return account, nil
} }
// GetUserAccount returns a user account // GetAccount returns an account for ID
func (s *FileStore) GetUserAccount(userId string) (*Account, error) { func (s *FileStore) GetAccount(accountID string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
accountId, accountIdFound := s.UserId2AccountId[userId] account, err := s.getAccount(accountID)
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "account not found")
}
return s.GetAccount(accountId)
}
func (s *FileStore) getPeerAccount(peerKey string) (*Account, error) {
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
}
return s.GetAccount(accountId)
}
// GetPeerAccount returns user account if exists
func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
return s.getPeerAccount(peerKey)
}
// GetPeerSrcRules return list of source rules for peer
func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey] return account.Copy(), nil
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
} }
// GetPeerDstRules return list of destination rules for peer // GetAccountByUser returns a user account
func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) { func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
account, err := s.GetAccount(accountId) accountID, accountIDFound := s.UserID2AccountID[userID]
if !accountIDFound {
return nil, status.Errorf(status.NotFound, "account not found")
}
account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey] return account.Copy(), nil
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
} }
// GetPeerRoutes return list of routes for peer // GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
func (s *FileStore) GetPeerRoutes(peerKey string) ([]*route.Route, error) { func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
account, err := s.getPeerAccount(peerKey) accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
if !accountIDFound {
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
}
account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var routes []*route.Route return account.Copy(), nil
routeIDs, ok := s.PeerKeyID2RouteIDs[peerKey]
if !ok {
return routes, nil
}
for id := range routeIDs {
route, found := account.Routes[id]
if found {
routes = append(routes, route)
}
}
return routes, nil
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (s *FileStore) GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountID)
if err != nil {
return nil, err
}
routeIDs, ok := s.AccountPrefix2RouteIDs[accountID][prefix.String()]
if !ok {
return nil, status.Errorf(codes.NotFound, "no routes for prefix: %v", prefix.String())
}
var routes []*route.Route
for _, id := range routeIDs {
route, found := account.Routes[id]
if found {
routes = append(routes, route)
}
}
return routes, nil
} }
// GetInstallationID returns the installation ID from the store // GetInstallationID returns the installation ID from the store
@@ -576,11 +295,42 @@ func (s *FileStore) GetInstallationID() string {
} }
// SaveInstallationID saves the installation ID // SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(id string) error { func (s *FileStore) SaveInstallationID(ID string) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.InstallationID = id s.InstallationID = ID
return s.persist(s.storeFile)
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerStatus) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
peer := account.Peers[peerKey]
if peer == nil {
return status.Errorf(status.NotFound, "peer %s not found", peerKey)
}
peer.Status = &peerStatus
return nil
}
// Close the FileStore persisting data to disk
func (s *FileStore) Close() error {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("closing FileStore")
return s.persist(s.storeFile) return s.persist(s.storeFile)
} }

View File

@@ -2,6 +2,7 @@ package server
import ( import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net" "net"
"path/filepath" "path/filepath"
@@ -9,6 +10,10 @@ import (
"time" "time"
) )
type accounts struct {
Accounts map[string]*Account
}
func TestNewStore(t *testing.T) { func TestNewStore(t *testing.T) {
store := newStore(t) store := newStore(t)
@@ -16,16 +21,16 @@ func TestNewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
} }
if store.SetupKeyId2AccountId == nil || len(store.SetupKeyId2AccountId) != 0 { if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 {
t.Errorf("expected to create a new empty SetupKeyId2AccountId map when creating a new FileStore") t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore")
} }
if store.PeerKeyId2AccountId == nil || len(store.PeerKeyId2AccountId) != 0 { if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 {
t.Errorf("expected to create a new empty PeerKeyId2AccountId map when creating a new FileStore") t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore")
} }
if store.UserId2AccountId == nil || len(store.UserId2AccountId) != 0 { if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 {
t.Errorf("expected to create a new empty UserId2AccountId map when creating a new FileStore") t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore")
} }
} }
@@ -55,16 +60,16 @@ func TestSaveAccount(t *testing.T) {
t.Errorf("expecting Account to be stored after SaveAccount()") t.Errorf("expecting Account to be stored after SaveAccount()")
} }
if store.PeerKeyId2AccountId["peerkey"] == "" { if store.PeerKeyID2AccountID["peerkey"] == "" {
t.Errorf("expecting PeerKeyId2AccountId index updated after SaveAccount()") t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()")
} }
if store.UserId2AccountId["testuser"] == "" { if store.UserID2AccountID["testuser"] == "" {
t.Errorf("expecting UserId2AccountId index updated after SaveAccount()") t.Errorf("expecting UserID2AccountID index updated after SaveAccount()")
} }
if store.SetupKeyId2AccountId[setupKey.Key] == "" { if store.SetupKeyID2AccountID[setupKey.Key] == "" {
t.Errorf("expecting SetupKeyId2AccountId index updated after SaveAccount()") t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()")
} }
} }
@@ -88,7 +93,7 @@ func TestStore(t *testing.T) {
return return
} }
restored, err := NewStore(store.storeFile) restored, err := NewFileStore(store.storeFile)
if err != nil { if err != nil {
return return
} }
@@ -124,7 +129,7 @@ func TestRestore(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewStore(storeDir) store, err := NewFileStore(storeDir)
if err != nil { if err != nil {
return return
} }
@@ -141,11 +146,11 @@ func TestRestore(t *testing.T) {
require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.Len(t, store.UserId2AccountId, 2, "failed to restore a FileStore wrong UserId2AccountId mapping length") require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length")
require.Len(t, store.SetupKeyId2AccountId, 1, "failed to restore a FileStore wrong SetupKeyId2AccountId mapping length") require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length")
require.Len(t, store.PrivateDomain2AccountId, 1, "failed to restore a FileStore wrong PrivateDomain2AccountId mapping length") require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length")
} }
func TestGetAccountByPrivateDomain(t *testing.T) { func TestGetAccountByPrivateDomain(t *testing.T) {
@@ -156,7 +161,7 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewStore(storeDir) store, err := NewFileStore(storeDir)
if err != nil { if err != nil {
return return
} }
@@ -171,8 +176,101 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
require.Error(t, err, "should return error on domain lookup") require.Error(t, err, "should return error on domain lookup")
} }
func TestFileStore_GetAccount(t *testing.T) {
storeDir := t.TempDir()
storeFile := filepath.Join(storeDir, "store.json")
err := util.CopyFileContents("testdata/store.json", storeFile)
if err != nil {
t.Fatal(err)
}
accounts := &accounts{}
_, err = util.ReadJson(storeFile, accounts)
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir)
if err != nil {
t.Fatal(err)
}
expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
if expected == nil {
t.Fatalf("expected account doesn't exist")
}
account, err := store.GetAccount(expected.Id)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
assert.Equal(t, expected.DomainCategory, account.DomainCategory)
assert.Equal(t, expected.Domain, account.Domain)
assert.Equal(t, expected.CreatedBy, account.CreatedBy)
assert.Equal(t, expected.Network.Id, account.Network.Id)
assert.Len(t, account.Peers, len(expected.Peers))
assert.Len(t, account.Users, len(expected.Users))
assert.Len(t, account.SetupKeys, len(expected.SetupKeys))
assert.Len(t, account.Rules, len(expected.Rules))
assert.Len(t, account.Routes, len(expected.Routes))
assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
}
func TestFileStore_SavePeerStatus(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir)
if err != nil {
return
}
account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
if err != nil {
t.Fatal(err)
}
// save status of non-existing peer
newStatus := PeerStatus{Connected: true, LastSeen: time.Now()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
// save new status of existing peer
account.Peers["testpeer"] = &Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{},
Name: "peer name",
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
}
err = store.SaveAccount(account)
if err != nil {
t.Fatal(err)
}
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
if err != nil {
t.Fatal(err)
}
account, err = store.getAccount(account.Id)
if err != nil {
t.Fatal(err)
}
actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual)
}
func newStore(t *testing.T) *FileStore { func newStore(t *testing.T) *FileStore {
store, err := NewStore(t.TempDir()) store, err := NewFileStore(t.TempDir())
if err != nil { if err != nil {
t.Errorf("failed creating a new store") t.Errorf("failed creating a new store")
} }

View File

@@ -1,9 +1,6 @@
package server package server
import ( import "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Group of the peers for ACL // Group of the peers for ACL
type Group struct { type Group struct {
@@ -47,12 +44,13 @@ func (g *Group) Copy() *Group {
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
@@ -60,17 +58,18 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
return group, nil return group, nil
} }
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error { func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
account.Groups[group.ID] = group account.Groups[group.ID] = group
@@ -86,17 +85,18 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
// UpdateGroup updates a group using a list of operations // UpdateGroup updates a group using a list of operations
func (am *DefaultAccountManager) UpdateGroup(accountID string, func (am *DefaultAccountManager) UpdateGroup(accountID string,
groupID string, operations []GroupUpdateOperation) (*Group, error) { groupID string, operations []GroupUpdateOperation) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
groupToUpdate, ok := account.Groups[groupID] groupToUpdate, ok := account.Groups[groupID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "group %s no longer exists", groupID) return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
} }
group := groupToUpdate.Copy() group := groupToUpdate.Copy()
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update account peers") return nil, err
} }
return group, nil return group, nil
@@ -135,12 +135,13 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
// DeleteGroup object of the peers // DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
delete(account.Groups, groupID) delete(account.Groups, groupID)
@@ -155,12 +156,13 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
groups := make([]*Group, 0, len(account.Groups)) groups := make([]*Group, 0, len(account.Groups))
@@ -173,17 +175,18 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
if !ok { if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
add := true add := true
@@ -207,17 +210,18 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
if !ok { if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
account.Network.IncSerial() account.Network.IncSerial()
@@ -225,7 +229,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
if itemID == peerKey { if itemID == peerKey {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil { if err := am.Store.SaveAccount(account); err != nil {
return status.Errorf(codes.Internal, "can't save account") return err
} }
} }
} }
@@ -235,17 +239,18 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
// GroupListPeers returns list of the peers from the group // GroupListPeers returns list of the peers from the group
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) { func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(status.NotFound, "account not found")
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
peers := make([]*Peer, 0, len(account.Groups)) peers := make([]*Peer, 0, len(account.Groups))

View File

@@ -3,7 +3,7 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/telemetry"
gPeer "google.golang.org/grpc/peer" gPeer "google.golang.org/grpc/peer"
"strings" "strings"
"time" "time"
@@ -14,6 +14,7 @@ import (
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
internalStatus "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -30,10 +31,12 @@ type GRPCServer struct {
config *Config config *Config
turnCredentialsManager TURNCredentialsManager turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware jwtMiddleware *middleware.JWTMiddleware
appMetrics telemetry.AppMetrics
} }
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager) (*GRPCServer, error) { func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager,
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -53,6 +56,16 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
log.Debug("unable to use http config to create new jwt middleware") log.Debug("unable to use http config to create new jwt middleware")
} }
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
})
if err != nil {
return nil, err
}
}
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
@@ -61,11 +74,15 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
config: config, config: config,
turnCredentialsManager: turnCredentialsManager, turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware, jwtMiddleware: jwtMiddleware,
appMetrics: appMetrics,
}, nil }, nil
} }
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
// todo introduce something more meaningful with the key expiration/rotation // todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
}
now := time.Now().Add(24 * time.Hour) now := time.Now().Add(24 * time.Hour)
secs := int64(now.Second()) secs := int64(now.Second())
nanos := int32(now.Nanosecond()) nanos := int32(now.Nanosecond())
@@ -80,6 +97,9 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account) // notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
p, ok := gRPCPeer.FromContext(srv.Context()) p, ok := gRPCPeer.FromContext(srv.Context())
if ok { if ok {
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
@@ -166,7 +186,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) { func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
var ( var (
reqSetupKey string reqSetupKey string
userId string userID string
) )
if req.GetJwtToken() != "" { if req.GetJwtToken() != "" {
@@ -181,16 +201,16 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
} }
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience) claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
_, err = s.accountManager.GetAccountFromToken(claims) userID = claims.UserId
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
} }
userId = claims.UserId
} else { } else {
log.Debugln("using setup key to register peer") log.Debugln("using setup key to register peer")
reqSetupKey = req.GetSetupKey() reqSetupKey = req.GetSetupKey()
userId = "" userID = ""
} }
meta := req.GetMeta() meta := req.GetMeta()
@@ -203,7 +223,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
sshKey = req.GetPeerKeys().GetSshPubKey() sshKey = req.GetPeerKeys().GetSshPubKey()
} }
peer, err := s.accountManager.AddPeer(reqSetupKey, userId, &Peer{ peer, err := s.accountManager.AddPeer(reqSetupKey, userID, &Peer{
Key: peerKey.String(), Key: peerKey.String(),
Name: meta.GetHostname(), Name: meta.GetHostname(),
SSHKey: string(sshKey), SSHKey: string(sshKey),
@@ -219,13 +239,16 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
}, },
}) })
if err != nil { if err != nil {
s, ok := status.FromError(err) if e, ok := internalStatus.FromError(err); ok {
if ok { switch e.Type() {
if s.Code() == codes.FailedPrecondition || s.Code() == codes.OutOfRange { case internalStatus.PreconditionFailed:
return nil, err return nil, status.Errorf(codes.FailedPrecondition, e.Message)
case internalStatus.NotFound:
return nil, status.Errorf(codes.NotFound, e.Message)
default:
} }
} }
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists") return nil, status.Errorf(codes.Internal, "failed registering new peer")
} }
// todo move to DefaultAccountManager the code below // todo move to DefaultAccountManager the code below
@@ -233,17 +256,14 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err) return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
} }
// notify other peers of our registration // notify other peers of our registration
for _, remotePeer := range networkMap.Peers { for _, remotePeer := range networkMap.Peers {
// exclude notified peer and add ourselves remotePeerNetworkMap, err := s.accountManager.GetNetworkMap(remotePeer.Key)
peersToSend := []*Peer{peer} if err != nil {
for _, p := range networkMap.Peers { return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
if remotePeer.Key != p.Key {
peersToSend = append(peersToSend, p)
}
} }
update := toSyncResponse(s.config, remotePeer, peersToSend, networkMap.Routes, nil, networkMap.Network.CurrentSerial(), networkMap.Network)
update := toSyncResponse(s.config, remotePeer, nil, remotePeerNetworkMap, s.accountManager.GetDNSDomain())
err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update}) err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update})
if err != nil { if err != nil {
// todo rethink if we should keep this return // todo rethink if we should keep this return
@@ -259,6 +279,9 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful // In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
p, ok := gRPCPeer.FromContext(ctx) p, ok := gRPCPeer.FromContext(ctx)
if ok { if ok {
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
@@ -278,7 +301,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, err := s.accountManager.GetPeer(peerKey.String()) peer, err := s.accountManager.GetPeer(peerKey.String())
if err != nil { if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound { if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
// peer doesn't exist -> check if setup key was provided // peer doesn't exist -> check if setup key was provided
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" { if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
// absent setup key or jwt -> permission denied // absent setup key or jwt -> permission denied
@@ -338,7 +361,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
// if peer has reached this point then it has logged in // if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{ loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, network), PeerConfig: toPeerConfig(peer, network, s.accountManager.GetDNSDomain()),
} }
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil { if err != nil {
@@ -364,12 +387,14 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
case TCP: case TCP:
return proto.HostConfig_TCP return proto.HostConfig_TCP
default: default:
// mbragin: todo something better?
panic(fmt.Errorf("unexpected config protocol type %v", configProto)) panic(fmt.Errorf("unexpected config protocol type %v", configProto))
} }
} }
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig { func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig var stuns []*proto.HostConfig
for _, stun := range config.Stuns { for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{ stuns = append(stuns, &proto.HostConfig{
@@ -408,34 +433,46 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
} }
} }
func toPeerConfig(peer *Peer, network *Network) *proto.PeerConfig { func toPeerConfig(peer *Peer, network *Network, dnsName string) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size() netmask, _ := network.Net.Mask.Size()
fqdn := ""
if dnsName != "" {
fqdn = peer.DNSLabel + "." + dnsName
}
return &proto.PeerConfig{ return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
Fqdn: fqdn,
} }
} }
func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig { func toRemotePeerConfig(peers []*Peer, dnsName string) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{} remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers { for _, rPeer := range peers {
fqdn := ""
if dnsName != "" {
fqdn = rPeer.DNSLabel + "." + dnsName
}
remotePeers = append(remotePeers, &proto.RemotePeerConfig{ remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key, WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
}) })
} }
return remotePeers return remotePeers
} }
func toSyncResponse(config *Config, peer *Peer, peers []*Peer, routes []*route.Route, turnCredentials *TURNCredentials, serial uint64, network *Network) *proto.SyncResponse { func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials) wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, network) pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(peers) remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(routes) routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
return &proto.SyncResponse{ return &proto.SyncResponse{
WiretrusteeConfig: wtConfig, WiretrusteeConfig: wtConfig,
@@ -443,11 +480,12 @@ func toSyncResponse(config *Config, peer *Peer, peers []*Peer, routes []*route.R
RemotePeers: remotePeers, RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0, RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: serial, Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig, PeerConfig: pConfig,
RemotePeers: remotePeers, RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0, RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate, Routes: routesUpdate,
DNSConfig: dnsUpdate,
}, },
} }
} }
@@ -473,7 +511,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.
} else { } else {
turnCredentials = nil turnCredentials = nil
} }
plainResp := toSyncResponse(s.config, peer, networkMap.Peers, networkMap.Routes, turnCredentials, networkMap.Network.CurrentSerial(), networkMap.Network) plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {

View File

@@ -136,6 +136,9 @@ components:
ui_version: ui_version:
description: Peer's desktop UI version description: Peer's desktop UI version
type: string type: string
dns_label:
description: Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
type: string
required: required:
- ip - ip
- connected - connected
@@ -145,6 +148,7 @@ components:
- groups - groups
- ssh_enabled - ssh_enabled
- hostname - hostname
- dns_label
SetupKey: SetupKey:
type: object type: object
properties: properties:
@@ -444,12 +448,24 @@ components:
type: array type: array
items: items:
type: string type: string
primary:
description: Nameserver group primary status
type: boolean
domains:
description: Nameserver group domain list
type: array
items:
type: string
minLength: 1
maxLength: 255
required: required:
- name - name
- description - description
- nameservers - nameservers
- enabled - enabled
- groups - groups
- primary
- domains
NameserverGroup: NameserverGroup:
allOf: allOf:
- type: object - type: object
@@ -468,7 +484,7 @@ components:
path: path:
description: Nameserver group field to update in form /<field> description: Nameserver group field to update in form /<field>
type: string type: string
enum: [ "name","description","enabled","groups","nameservers" ] enum: [ "name", "description", "enabled", "groups", "nameservers", "primary", "domains" ]
required: required:
- path - path

View File

@@ -39,10 +39,12 @@ const (
// Defines values for NameserverGroupPatchOperationPath. // Defines values for NameserverGroupPatchOperationPath.
const ( const (
NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description" NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description"
NameserverGroupPatchOperationPathDomains NameserverGroupPatchOperationPath = "domains"
NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled" NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled"
NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups" NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups"
NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name" NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name"
NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers" NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers"
NameserverGroupPatchOperationPathPrimary NameserverGroupPatchOperationPath = "primary"
) )
// Defines values for PatchMinimumOp. // Defines values for PatchMinimumOp.
@@ -159,6 +161,9 @@ type NameserverGroup struct {
// Description Nameserver group description // Description Nameserver group description
Description string `json:"description"` Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status // Enabled Nameserver group status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@@ -173,6 +178,9 @@ type NameserverGroup struct {
// Nameservers Nameserver group // Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"` Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
} }
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation. // NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
@@ -198,6 +206,9 @@ type NameserverGroupRequest struct {
// Description Nameserver group description // Description Nameserver group description
Description string `json:"description"` Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status // Enabled Nameserver group status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@@ -209,6 +220,9 @@ type NameserverGroupRequest struct {
// Nameservers Nameserver group // Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"` Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
} }
// PatchMinimum defines model for PatchMinimum. // PatchMinimum defines model for PatchMinimum.
@@ -228,6 +242,9 @@ type Peer struct {
// Connected Peer to Management connection status // Connected Peer to Management connection status
Connected bool `json:"connected"` Connected bool `json:"connected"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// Groups Groups that the peer belongs to // Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"` Groups []GroupMinimum `json:"groups"`

View File

@@ -2,10 +2,9 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/http/util"
"google.golang.org/grpc/status" "github.com/netbirdio/netbird/management/server/status"
"net/http" "net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@@ -33,7 +32,8 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account // GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -45,52 +45,54 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
groups = append(groups, toGroupResponse(account, g)) groups = append(groups, toGroupResponse(account, g))
} }
writeJSONObject(w, groups) util.WriteJSONObject(w, groups)
} }
// UpdateGroupHandler handles update to a group identified by a given ID // UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
groupID, ok := vars["id"] groupID, ok := vars["id"]
if !ok { if !ok {
http.Error(w, "group ID field is missing", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return return
} }
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "group ID can't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
return return
} }
_, ok = account.Groups[groupID] _, ok = account.Groups[groupID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find group with ID %s", groupID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed) util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return return
} }
var req api.PutApiGroupsIdJSONRequestBody var req api.PutApiGroupsIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if *req.Name == "" { if *req.Name == "" {
http.Error(w, "group name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@@ -102,53 +104,55 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil { if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, toGroupResponse(account, &group)) util.WriteJSONObject(w, toGroupResponse(account, &group))
} }
// PatchGroupHandler handles patch updates to a group identified by a given ID // PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
groupID := vars["id"] groupID := vars["id"]
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "invalid group Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
_, ok := account.Groups[groupID] _, ok := account.Groups[groupID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find group id %s", groupID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed) util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return return
} }
var req api.PatchApiGroupsIdJSONRequestBody var req api.PatchApiGroupsIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if len(req) == 0 { if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return return
} }
@@ -158,13 +162,13 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path { switch patch.Path {
case api.GroupPatchOperationPathName: case api.GroupPatchOperationPathName:
if patch.Op != api.GroupPatchOperationOpReplace { if patch.Op != api.GroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "name field only accepts replace operation, got %s", patch.Op), w)
return return
} }
if len(patch.Value) == 0 || patch.Value[0] == "" { if len(patch.Value) == 0 || patch.Value[0] == "" {
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@@ -193,53 +197,43 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
Values: peerKeys, Values: peerKeys,
}) })
default: default:
http.Error(w, "invalid operation, \"%s\", for Peers field", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation, \"%v\", for Peers field", patch.Op), w)
return return
} }
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations) group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, toGroupResponse(account, group)) util.WriteJSONObject(w, toGroupResponse(account, group))
} }
// CreateGroupHandler handles group creation request // CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
var req api.PostApiGroupsJSONRequestBody var req api.PostApiGroupsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@@ -249,55 +243,57 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
Peers: peerIPsToKeys(account, req.Peers), Peers: peerIPsToKeys(account, req.Peers),
} }
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil { err = h.accountManager.SaveGroup(account.Id, &group)
log.Errorf("failed creating group \"%s\" under account %s %v", req.Name, account.Id, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, toGroupResponse(account, &group)) util.WriteJSONObject(w, toGroupResponse(account, &group))
} }
// DeleteGroupHandler handles group deletion request // DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
aID := account.Id aID := account.Id
groupID := mux.Vars(r)["id"] groupID := mux.Vars(r)["id"]
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
http.Error(w, "deleting group ALL is not allowed", http.StatusMethodNotAllowed) util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return return
} }
if err := h.accountManager.DeleteGroup(aID, groupID); err != nil { err = h.accountManager.DeleteGroup(aID, groupID)
log.Errorf("failed delete group %s under account %s %v", groupID, aID, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetGroupHandler returns a group // GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
@@ -305,19 +301,22 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
case http.MethodGet: case http.MethodGet:
groupID := mux.Vars(r)["id"] groupID := mux.Vars(r)["id"]
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
group, err := h.accountManager.GetGroup(account.Id, groupID) group, err := h.accountManager.GetGroup(account.Id, groupID)
if err != nil { if err != nil {
http.Error(w, "group not found", http.StatusNotFound) util.WriteError(err, w)
return return
} }
writeJSONObject(w, toGroupResponse(account, group)) util.WriteJSONObject(w, toGroupResponse(account, group))
default: default:
http.Error(w, "", http.StatusNotFound) if err != nil {
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
return
}
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"io" "io"
"net" "net"
"net/http" "net/http"
@@ -21,11 +22,11 @@ import (
) )
var TestPeers = map[string]*server.Peer{ var TestPeers = map[string]*server.Peer{
"A": &server.Peer{Key: "A", IP: net.ParseIP("100.100.100.100")}, "A": {Key: "A", IP: net.ParseIP("100.100.100.100")},
"B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", IP: net.ParseIP("200.200.200.200")},
} }
func initGroupTestData(groups ...*server.Group) *Groups { func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
return &Groups{ return &Groups{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID string, group *server.Group) error { SaveGroupFunc: func(accountID string, group *server.Group) error {
@@ -36,7 +37,7 @@ func initGroupTestData(groups ...*server.Group) *Groups {
}, },
GetGroupFunc: func(_, groupID string) (*server.Group, error) { GetGroupFunc: func(_, groupID string) (*server.Group, error) {
if groupID != "idofthegroup" { if groupID != "idofthegroup" {
return nil, fmt.Errorf("not found") return nil, status.Errorf(status.NotFound, "not found")
} }
return &server.Group{ return &server.Group{
ID: "idofthegroup", ID: "idofthegroup",
@@ -67,15 +68,18 @@ func initGroupTestData(groups ...*server.Group) *Groups {
} }
return nil, fmt.Errorf("peer not found") return nil, fmt.Errorf("peer not found")
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: TestPeers, Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"}},
}, nil }, user, nil
}, },
}, },
authAudience: "", authAudience: "",
@@ -120,7 +124,8 @@ func TestGetGroup(t *testing.T) {
Name: "Group", Name: "Group",
} }
p := initGroupTestData(group) adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -219,7 +224,7 @@ func TestWriteGroup(t *testing.T) {
requestPath: "/api/groups/id-all", requestPath: "/api/groups/id-all",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`{"Name":"super"}`)), []byte(`{"Name":"super"}`)),
expectedStatus: http.StatusMethodNotAllowed, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -240,7 +245,7 @@ func TestWriteGroup(t *testing.T) {
requestPath: "/api/groups/id-existed", requestPath: "/api/groups/id-existed",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)), []byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -270,7 +275,8 @@ func TestWriteGroup(t *testing.T) {
}, },
} }
p := initGroupTestData() adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@@ -4,12 +4,14 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/cors" "github.com/rs/cors"
"net/http" "net/http"
) )
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string) (http.Handler, error) { func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string,
appMetrics telemetry.AppMetrics) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware( jwtMiddleware, err := middleware.NewJwtMiddleware(
authIssuer, authIssuer,
authAudience, authAudience,
@@ -21,12 +23,15 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
acMiddleware := middleware.NewAccessControll( acMiddleware := middleware.NewAccessControl(
authAudience, authAudience,
accountManager.IsUserAdmin) accountManager.IsUserAdmin)
apiHandler := mux.NewRouter() rootRouter := mux.NewRouter()
apiHandler.Use(corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) metricsMiddleware := appMetrics.HTTPMiddleware()
apiHandler := rootRouter.PathPrefix("/api").Subrouter()
apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
groupsHandler := NewGroups(accountManager, authAudience) groupsHandler := NewGroups(accountManager, authAudience)
rulesHandler := NewRules(accountManager, authAudience) rulesHandler := NewRules(accountManager, authAudience)
@@ -36,46 +41,67 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
routesHandler := NewRoutes(accountManager, authAudience) routesHandler := NewRoutes(accountManager, authAudience)
nameserversHandler := NewNameservers(accountManager, authAudience) nameserversHandler := NewNameservers(accountManager, authAudience)
apiHandler.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer). apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS") Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS")
return apiHandler, nil err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err
}
for _, method := range methods {
template, err := route.GetPathTemplate()
if err != nil {
return err
}
err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return rootRouter, nil
} }

View File

@@ -1,7 +1,8 @@
package middleware package middleware
import ( import (
"fmt" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http" "net/http"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -9,38 +10,39 @@ import (
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControll middleware to restrict to make POST/PUT/DELETE requests by admin only // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControll struct { type AccessControl struct {
jwtExtractor jwtclaims.ClaimsExtractor jwtExtractor jwtclaims.ClaimsExtractor
isUserAdmin IsUserAdminFunc isUserAdmin IsUserAdminFunc
audience string audience string
} }
// NewAccessControll instance constructor // NewAccessControl instance constructor
func NewAccessControll(audience string, isUserAdmin IsUserAdminFunc) *AccessControll { func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl {
return &AccessControll{ return &AccessControl{
isUserAdmin: isUserAdmin, isUserAdmin: isUserAdmin,
audience: audience, audience: audience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
} }
} }
// Handler method of the middleware which forbinneds all modify requests for non admin users // Handler method of the middleware which forbids all modify requests for non admin users
func (a *AccessControll) Handler(h http.Handler) http.Handler { // It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience) jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience)
ok, err := a.isUserAdmin(jwtClaims) ok, err := a.isUserAdmin(jwtClaims)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized) util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return
} }
if !ok { if !ok {
switch r.Method { switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
http.Error(w, "user is not admin", http.StatusForbidden) util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return return
} }
} }

View File

@@ -7,12 +7,12 @@ import (
"net/http" "net/http"
) )
//Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation // Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct { type Jwks struct {
Keys []JSONWebKeys `json:"keys"` Keys []JSONWebKeys `json:"keys"`
} }
//JSONWebKeys is a representation of a Jason Web Key // JSONWebKeys is a representation of a Jason Web Key
type JSONWebKeys struct { type JSONWebKeys struct {
Kty string `json:"kty"` Kty string `json:"kty"`
Kid string `json:"kid"` Kid string `json:"kid"`
@@ -22,7 +22,7 @@ type JSONWebKeys struct {
X5c []string `json:"x5c"` X5c []string `json:"x5c"`
} }
//NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header // NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
keys, err := getPemKeys(keysLocation) keys, err := getPemKeys(keysLocation)
@@ -66,7 +66,6 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
var jwks = &Jwks{} var jwks = &Jwks{}
err = json.NewDecoder(resp.Body).Decode(jwks) err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil { if err != nil {
return jwks, err return jwks, err
} }

View File

@@ -5,6 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"log" "log"
"net/http" "net/http"
"strings" "strings"
@@ -57,7 +59,7 @@ type JWTMiddleware struct {
} }
func OnError(w http.ResponseWriter, r *http.Request, err string) { func OnError(w http.ResponseWriter, r *http.Request, err string) {
http.Error(w, err, http.StatusUnauthorized) util.WriteError(status.Errorf(status.Unauthorized, ""), w)
} }
// New constructs a new Secure instance with supplied options. // New constructs a new Secure instance with supplied options.
@@ -186,7 +188,7 @@ func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Reque
validatedToken, err := m.ValidateAndParse(token) validatedToken, err := m.ValidateAndParse(token)
if err != nil { if err != nil {
m.Options.ErrorHandler(w, r, "The token isn't valid") m.Options.ErrorHandler(w, r, err.Error())
return err return err
} }

View File

@@ -7,7 +7,9 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http" "net/http"
) )
@@ -30,7 +32,8 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
// GetAllNameserversHandler returns the list of nameserver groups for the account // GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -39,7 +42,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id) nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
@@ -48,64 +51,67 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r)) apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
} }
writeJSONObject(w, apiNameservers) util.WriteJSONObject(w, apiNameservers)
} }
// CreateNameserverGroupHandler handles nameserver group creation request // CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
var req api.PostApiDnsNameserversJSONRequestBody var req api.PostApiDnsNameserversJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
nsList, err := toServerNSList(req.Nameservers) nsList, err := toServerNSList(req.Nameservers)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Enabled) nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(nsGroup) resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID // UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
var req api.PutApiDnsNameserversIdJSONRequestBody var req api.PutApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
nsList, err := toServerNSList(req.Nameservers) nsList, err := toServerNSList(req.Nameservers)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return return
} }
@@ -113,6 +119,8 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
ID: nsGroupID, ID: nsGroupID,
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Primary: req.Primary,
Domains: req.Domains,
NameServers: nsList, NameServers: nsList,
Groups: req.Groups, Groups: req.Groups,
Enabled: req.Enabled, Enabled: req.Enabled,
@@ -120,41 +128,42 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup) err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(updatedNSGroup) resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID // PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
var req api.PatchApiDnsNameserversIdJSONRequestBody var req api.PatchApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
var operations []server.NameServerGroupUpdateOperation var operations []server.NameServerGroupUpdateOperation
for _, patch := range req { for _, patch := range req {
if patch.Op != api.NameserverGroupPatchOperationOpReplace { if patch.Op != api.NameserverGroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "nameserver groups only accepts replace operations, got %s", patch.Op), w)
return return
} }
switch patch.Path { switch patch.Path {
@@ -168,6 +177,16 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
Type: server.UpdateNameServerGroupDescription, Type: server.UpdateNameServerGroupDescription,
Values: patch.Value, Values: patch.Value,
}) })
case api.NameserverGroupPatchOperationPathPrimary:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupPrimary,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathDomains:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupDomains,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathNameservers: case api.NameserverGroupPatchOperationPathNameservers:
operations = append(operations, server.NameServerGroupUpdateOperation{ operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupNameServers, Type: server.UpdateNameServerGroupNameServers,
@@ -184,49 +203,50 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations) updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(updatedNSGroup) resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// DeleteNameserverGroupHandler handles nameserver group deletion request // DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID) err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID // GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -235,19 +255,19 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID) nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(nsGroup) resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
@@ -279,6 +299,8 @@ func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.Namese
Id: serverNSGroup.ID, Id: serverNSGroup.ID,
Name: serverNSGroup.Name, Name: serverNSGroup.Name,
Description: serverNSGroup.Description, Description: serverNSGroup.Description,
Primary: serverNSGroup.Primary,
Domains: serverNSGroup.Domains,
Groups: serverNSGroup.Groups, Groups: serverNSGroup.Groups,
Nameservers: nsList, Nameservers: nsList,
Enabled: serverNSGroup.Enabled, Enabled: serverNSGroup.Enabled,

View File

@@ -5,9 +5,8 @@ import (
"encoding/json" "encoding/json"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -29,12 +28,16 @@ const (
var testingNSAccount = &server.Account{ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID, Id: testNSGroupAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
} }
var baseExistingNSGroup = &nbdns.NameServerGroup{ var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -58,9 +61,9 @@ func initNameserversTestData() *Nameservers {
if nsGroupID == existingNSGroupID { if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil return baseExistingNSGroup.Copy(), nil
} }
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
}, },
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{ return &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: name, Name: name,
@@ -68,6 +71,8 @@ func initNameserversTestData() *Nameservers {
NameServers: nameServerList, NameServers: nameServerList,
Groups: groups, Groups: groups,
Enabled: enabled, Enabled: enabled,
Primary: primary,
Domains: domains,
}, nil }, nil
}, },
DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error { DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error {
@@ -77,12 +82,12 @@ func initNameserversTestData() *Nameservers {
if nsGroupToSave.ID == existingNSGroupID { if nsGroupToSave.ID == existingNSGroupID {
return nil return nil
} }
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
}, },
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
nsGroupToUpdate := baseExistingNSGroup.Copy() nsGroupToUpdate := baseExistingNSGroup.Copy()
if nsGroupID != nsGroupToUpdate.ID { if nsGroupID != nsGroupToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
} }
for _, operation := range operations { for _, operation := range operations {
switch operation.Type { switch operation.Type {
@@ -104,8 +109,8 @@ func initNameserversTestData() *Nameservers {
} }
return nsGroupToUpdate, nil return nsGroupToUpdate, nil
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, nil return testingNSAccount, nil, nil
}, },
}, },
authAudience: "", authAudience: "",
@@ -150,7 +155,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/dns/nameservers", requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedNSGroup: &api.NameserverGroup{ expectedNSGroup: &api.NameserverGroup{
@@ -166,6 +171,7 @@ func TestNameserversHandlers(t *testing.T) {
}, },
Groups: []string{"group"}, Groups: []string{"group"},
Enabled: true, Enabled: true,
Primary: true,
}, },
}, },
{ {
@@ -173,8 +179,8 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/dns/nameservers", requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -182,7 +188,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + existingNSGroupID, requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedNSGroup: &api.NameserverGroup{ expectedNSGroup: &api.NameserverGroup{
@@ -198,6 +204,7 @@ func TestNameserversHandlers(t *testing.T) {
}, },
Groups: []string{"group"}, Groups: []string{"group"},
Enabled: true, Enabled: true,
Primary: true,
}, },
}, },
{ {
@@ -205,7 +212,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
@@ -214,8 +221,8 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -232,6 +239,7 @@ func TestNameserversHandlers(t *testing.T) {
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers, Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
Groups: baseExistingNSGroup.Groups, Groups: baseExistingNSGroup.Groups,
Enabled: baseExistingNSGroup.Enabled, Enabled: baseExistingNSGroup.Enabled,
Primary: baseExistingNSGroup.Primary,
}, },
}, },
{ {

View File

@@ -6,8 +6,9 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/status"
"net/http" "net/http"
) )
@@ -28,53 +29,53 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) { func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
req := &api.PutApiPeersIdJSONBody{} req := &api.PutApiPeersIdJSONBody{}
peerIp := peer.IP
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name} update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name}
peer, err = h.accountManager.UpdatePeer(account.Id, update) peer, err = h.accountManager.UpdatePeer(account.Id, update)
if err != nil { if err != nil {
log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, toPeerResponse(peer, account)) dnsDomain := h.accountManager.GetDNSDomain()
util.WriteJSONObject(w, toPeerResponse(peer, account, dnsDomain))
} }
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) { func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
_, err := h.accountManager.DeletePeer(accountId, peer.Key) _, err := h.accountManager.DeletePeer(accountId, peer.Key)
if err != nil { if err != nil {
log.Errorf("failed deleteing peer %s, %v", peer.IP, err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
peerId := vars["id"] //effectively peer IP address peerId := vars["id"] //effectively peer IP address
if len(peerId) == 0 { if len(peerId) == 0 {
http.Error(w, "invalid peer Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return return
} }
peer, err := h.accountManager.GetPeerByIP(account.Id, peerId) peer, err := h.accountManager.GetPeerByIP(account.Id, peerId)
if err != nil { if err != nil {
http.Error(w, "peer not found", http.StatusNotFound) util.WriteError(err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain()
switch r.Method { switch r.Method {
case http.MethodDelete: case http.MethodDelete:
h.deletePeer(account.Id, peer, w, r) h.deletePeer(account.Id, peer, w, r)
@@ -83,11 +84,11 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
h.updatePeer(account, peer, w, r) h.updatePeer(account, peer, w, r)
return return
case http.MethodGet: case http.MethodGet:
writeJSONObject(w, toPeerResponse(peer, account)) util.WriteJSONObject(w, toPeerResponse(peer, account, dnsDomain))
return return
default: default:
http.Error(w, "", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
} }
} }
@@ -95,25 +96,33 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
respBody := []*api.Peer{} peers, err := h.accountManager.GetPeers(account.Id, user.Id)
for _, peer := range account.Peers { if err != nil {
respBody = append(respBody, toPeerResponse(peer, account)) util.WriteError(err, w)
return
} }
writeJSONObject(w, respBody)
dnsDomain := h.accountManager.GetDNSDomain()
respBody := []*api.Peer{}
for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account, dnsDomain))
}
util.WriteJSONObject(w, respBody)
return return
default: default:
http.Error(w, "", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
} }
} }
func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer { func toPeerResponse(peer *server.Peer, account *server.Account, dnsDomain string) *api.Peer {
var groupsInfo []api.GroupMinimum var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{}) groupsChecked := make(map[string]struct{})
for _, group := range account.Groups { for _, group := range account.Groups {
@@ -134,6 +143,10 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
} }
} }
} }
fqdn := peer.DNSLabel
if dnsDomain != "" {
fqdn = peer.DNSLabel + "." + dnsDomain
}
return &api.Peer{ return &api.Peer{
Id: peer.IP.String(), Id: peer.IP.String(),
Name: peer.Name, Name: peer.Name,
@@ -147,5 +160,6 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Hostname: peer.Meta.Hostname, Hostname: peer.Meta.Hostname,
UserId: &peer.UserID, UserId: &peer.UserID,
UiVersion: &peer.Meta.UIVersion, UiVersion: &peer.Meta.UIVersion,
DnsLabel: fqdn,
} }
} }

View File

@@ -16,17 +16,24 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
func initTestMetaData(peer ...*server.Peer) *Peers { func initTestMetaData(peers ...*server.Peer) *Peers {
return &Peers{ return &Peers{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: map[string]*server.Peer{ Peers: map[string]*server.Peer{
"test_peer": peer[0], "test_peer": peers[0],
}, },
}, nil Users: map[string]*server.User{
"test_user": user,
},
}, user, nil
}, },
}, },
authAudience: "", authAudience: "",

View File

@@ -2,15 +2,13 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http" "net/http"
"unicode/utf8" "unicode/utf8"
) )
@@ -33,17 +31,16 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account // GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
routes, err := h.accountManager.ListRoutes(account.Id) routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
apiRoutes := make([]*api.Route, 0) apiRoutes := make([]*api.Route, 0)
@@ -51,20 +48,22 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
apiRoutes = append(apiRoutes, toRouteResponse(account, r)) apiRoutes = append(apiRoutes, toRouteResponse(account, r))
} }
writeJSONObject(w, apiRoutes) util.WriteJSONObject(w, apiRoutes)
} }
// CreateRouteHandler handles route creation request // CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
var req api.PostApiRoutesJSONRequestBody var req api.PostApiRoutesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
@@ -72,8 +71,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
if req.Peer != "" { if req.Peer != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer) peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
return return
} }
peerKey = peer.Key peerKey = peer.Key
@@ -81,57 +79,60 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
_, newPrefix, err := route.ParseNetwork(req.Network) _, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s", req.Network), http.StatusBadRequest) util.WriteError(err, w)
return return
} }
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d",
route.MaxNetIDChar), w)
return return
} }
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled) newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRouteResponse(account, newRoute) resp := toRouteResponse(account, newRoute)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// UpdateRouteHandler handles update to a route identified by a given ID // UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
routeID := vars["id"] routeID := vars["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID) _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound) util.WriteError(err, w)
return return
} }
var req api.PutApiRoutesIdJSONRequestBody var req api.PutApiRoutesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
prefixType, newPrefix, err := route.ParseNetwork(req.Network) prefixType, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s for route ID %s", req.Network, routeID), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s",
req.Network, routeID), w)
return return
} }
@@ -139,15 +140,15 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
if req.Peer != "" { if req.Peer != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer) peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
return return
} }
peerKey = peer.Key peerKey = peer.Key
} }
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"identifier should be between 1 and %d", route.MaxNetIDChar), w)
return return
} }
@@ -165,46 +166,46 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
err = h.accountManager.SaveRoute(account.Id, newRoute) err = h.accountManager.SaveRoute(account.Id, newRoute)
if err != nil { if err != nil {
log.Errorf("failed updating route \"%s\" under account %s %v", routeID, account.Id, err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRouteResponse(account, newRoute) resp := toRouteResponse(account, newRoute)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// PatchRouteHandler handles patch updates to a route identified by a given ID // PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
routeID := vars["id"] routeID := vars["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID) _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
return return
} }
var req api.PatchApiRoutesIdJSONRequestBody var req api.PatchApiRoutesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if len(req) == 0 { if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return return
} }
@@ -214,8 +215,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path { switch patch.Path {
case api.RoutePatchOperationPathNetwork: case api.RoutePatchOperationPathNetwork:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Network field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "network field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@@ -224,8 +225,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathDescription: case api.RoutePatchOperationPathDescription:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "description field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@@ -234,8 +235,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathNetworkId: case api.RoutePatchOperationPathNetworkId:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Network Identifier field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "network Identifier field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@@ -244,21 +245,20 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathPeer: case api.RoutePatchOperationPathPeer:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Peer field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "peer field only accepts replace operation, got %s", patch.Op), w)
return return
} }
if len(patch.Value) > 1 { if len(patch.Value) > 1 {
http.Error(w, fmt.Sprintf("Value field only accepts 1 value, got %d", len(patch.Value)), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "value field only accepts 1 value, got %d", len(patch.Value)), w)
return return
} }
peerValue := patch.Value peerValue := patch.Value
if patch.Value[0] != "" { if patch.Value[0] != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0]) peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0])
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
return return
} }
peerValue = []string{peer.Key} peerValue = []string{peer.Key}
@@ -269,8 +269,9 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathMetric: case api.RoutePatchOperationPathMetric:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Metric field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "metric field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@@ -279,8 +280,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathMasquerade: case api.RoutePatchOperationPathMasquerade:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Masquerade field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "masquerade field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@@ -289,8 +290,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathEnabled: case api.RoutePatchOperationPathEnabled:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Enabled field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "enabled field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@@ -298,90 +299,68 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations) route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
log.Errorf("failed updating route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRouteResponse(account, route) resp := toRouteResponse(account, route)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// DeleteRouteHandler handles route deletion request // DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
routeID := mux.Vars(r)["id"] routeID := mux.Vars(r)["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
err = h.accountManager.DeleteRoute(account.Id, routeID) err = h.accountManager.DeleteRoute(account.Id, routeID)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetRouteHandler handles a route Get request identified by ID // GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
routeID := mux.Vars(r)["id"] routeID := mux.Vars(r)["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID) foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
http.Error(w, "route not found", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
return return
} }
writeJSONObject(w, toRouteResponse(account, foundRoute)) util.WriteJSONObject(w, toRouteResponse(account, foundRoute))
} }
func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route { func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route {

View File

@@ -5,9 +5,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -51,16 +50,19 @@ var testingAccount = &server.Account{
IP: netip.MustParseAddr(existingPeerID).AsSlice(), IP: netip.MustParseAddr(existingPeerID).AsSlice(),
}, },
}, },
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
} }
func initRoutesTestData() *Routes { func initRoutesTestData() *Routes {
return &Routes{ return &Routes{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID string) (*route.Route, error) { GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
if routeID == existingRouteID { if routeID == existingRouteID {
return baseExistingRoute, nil return baseExistingRoute, nil
} }
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}, },
CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
networkType, p, _ := route.ParseNetwork(network) networkType, p, _ := route.ParseNetwork(network)
@@ -80,13 +82,13 @@ func initRoutesTestData() *Routes {
}, },
DeleteRouteFunc: func(_ string, peerIP string) error { DeleteRouteFunc: func(_ string, peerIP string) error {
if peerIP != existingRouteID { if peerIP != existingRouteID {
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP) return status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
} }
return nil return nil
}, },
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
if peerIP != existingPeerID { if peerIP != existingPeerID {
return nil, status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP) return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
} }
return &server.Peer{ return &server.Peer{
Key: existingPeerKey, Key: existingPeerKey,
@@ -96,7 +98,7 @@ func initRoutesTestData() *Routes {
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) { UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
routeToUpdate := baseExistingRoute routeToUpdate := baseExistingRoute
if routeID != routeToUpdate.ID { if routeID != routeToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID) return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
} }
for _, operation := range operations { for _, operation := range operations {
switch operation.Type { switch operation.Type {
@@ -120,8 +122,8 @@ func initRoutesTestData() *Routes {
} }
return routeToUpdate, nil return routeToUpdate, nil
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, nil return testingAccount, testingAccount.Users["test_user"], nil
}, },
}, },
authAudience: "", authAudience: "",
@@ -198,15 +200,15 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/routes", requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
{ {
name: "POST Not Invalid Network Identifier", name: "POST Invalid Network Identifier",
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/routes", requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -214,7 +216,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/routes", requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -248,7 +250,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -256,7 +258,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -264,7 +266,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@@ -309,7 +311,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPatch, requestType: http.MethodPatch,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
{ {

View File

@@ -2,15 +2,13 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http" "net/http"
) )
@@ -31,50 +29,56 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account // GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
rules := []*api.Rule{} rules := []*api.Rule{}
for _, r := range account.Rules { for _, r := range accountRules {
rules = append(rules, toRuleResponse(account, r)) rules = append(rules, toRuleResponse(account, r))
} }
writeJSONObject(w, rules) util.WriteJSONObject(w, rules)
} }
// UpdateRuleHandler handles update to a rule identified by a given ID // UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
ruleID := vars["id"] ruleID := vars["id"]
if len(ruleID) == 0 { if len(ruleID) == 0 {
http.Error(w, "invalid rule Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
_, ok := account.Rules[ruleID] _, ok := account.Rules[ruleID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w)
return return
} }
var req api.PutApiRulesIdJSONRequestBody var req api.PutApiRulesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
return util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return return
} }
@@ -101,50 +105,52 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
case server.TrafficFlowBidirectString: case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect rule.Flow = server.TrafficFlowBidirect
default: default:
http.Error(w, "unknown flow type", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return return
} }
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil { err = h.accountManager.SaveRule(account.Id, &rule)
log.Errorf("failed updating rule \"%s\" under account %s %v", ruleID, account.Id, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
resp := toRuleResponse(account, &rule) resp := toRuleResponse(account, &rule)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// PatchRuleHandler handles patch updates to a rule identified by a given ID // PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
ruleID := vars["id"] ruleID := vars["id"]
if len(ruleID) == 0 { if len(ruleID) == 0 {
http.Error(w, "invalid rule Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
_, ok := account.Rules[ruleID] _, ok := account.Rules[ruleID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find rule ID %s", ruleID), w)
return return
} }
var req api.PatchApiRulesIdJSONRequestBody var req api.PatchApiRulesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if len(req) == 0 { if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return return
} }
@@ -154,12 +160,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path { switch patch.Path {
case api.RulePatchOperationPathName: case api.RulePatchOperationPathName:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "name field only accepts replace operation, got %s", patch.Op), w)
return return
} }
if len(patch.Value) == 0 || patch.Value[0] == "" { if len(patch.Value) == 0 || patch.Value[0] == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@@ -168,8 +174,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RulePatchOperationPathDescription: case api.RulePatchOperationPathDescription:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "description field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@@ -178,8 +184,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RulePatchOperationPathFlow: case api.RulePatchOperationPathFlow:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Flow field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "flow field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@@ -188,8 +194,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RulePatchOperationPathDisabled: case api.RulePatchOperationPathDisabled:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Disabled field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "disabled field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@@ -214,7 +220,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid operation, \"%s\", for Source field", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Source field", patch.Op), w)
return return
} }
case api.RulePatchOperationPathDestinations: case api.RulePatchOperationPathDestinations:
@@ -235,11 +242,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid operation, \"%s\", for Destination field", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Destination field", patch.Op), w)
return return
} }
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
@@ -247,48 +255,33 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations) rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
log.Errorf("failed updating rule %s under account %s %v", ruleID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRuleResponse(account, rule) resp := toRuleResponse(account, rule)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// CreateRuleHandler handles rule creation request // CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
var req api.PostApiRulesJSONRequestBody var req api.PostApiRulesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return return
} }
@@ -315,50 +308,52 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
case server.TrafficFlowBidirectString: case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect rule.Flow = server.TrafficFlowBidirect
default: default:
http.Error(w, "unknown flow type", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return return
} }
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil { err = h.accountManager.SaveRule(account.Id, &rule)
log.Errorf("failed creating rule \"%s\" under account %s %v", req.Name, account.Id, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
resp := toRuleResponse(account, &rule) resp := toRuleResponse(account, &rule)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// DeleteRuleHandler handles rule deletion request // DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
aID := account.Id aID := account.Id
rID := mux.Vars(r)["id"] rID := mux.Vars(r)["id"]
if len(rID) == 0 { if len(rID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
if err := h.accountManager.DeleteRule(aID, rID); err != nil { err = h.accountManager.DeleteRule(aID, rID)
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetRuleHandler handles a group Get request identified by ID // GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
@@ -366,19 +361,19 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
case http.MethodGet: case http.MethodGet:
ruleID := mux.Vars(r)["id"] ruleID := mux.Vars(r)["id"]
if len(ruleID) == 0 { if len(ruleID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
rule, err := h.accountManager.GetRule(account.Id, ruleID) rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
if err != nil { if err != nil {
http.Error(w, "rule not found", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "rule not found"), w)
return return
} }
writeJSONObject(w, toRuleResponse(account, rule)) util.WriteJSONObject(w, toRuleResponse(account, rule))
default: default:
http.Error(w, "", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
} }
} }

View File

@@ -28,7 +28,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
} }
return nil return nil
}, },
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) { GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
if ruleID != "idoftherule" { if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found") return nil, fmt.Errorf("not found")
} }
@@ -66,7 +66,8 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
} }
return &rule, nil return &rule, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
@@ -75,7 +76,10 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
"F": {ID: "F"}, "F": {ID: "F"},
"G": {ID: "G"}, "G": {ID: "G"},
}, },
}, nil Users: map[string]*server.User{
"test_user": user,
},
}, user, nil
}, },
}, },
authAudience: "", authAudience: "",
@@ -235,7 +239,7 @@ func TestRulesWriteRule(t *testing.T) {
requestPath: "/api/rules/id-existed", requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)), []byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {

View File

@@ -2,14 +2,12 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http" "net/http"
"time" "time"
) )
@@ -31,29 +29,28 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey // CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
req := &api.PostApiSetupKeysJSONRequestBody{} req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Setup key name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
return return
} }
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", string(req.Type)), w)
http.Error(w, "unknown setup key type "+string(req.Type), http.StatusBadRequest)
return return
} }
@@ -62,17 +59,11 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
if req.AutoGroups == nil { if req.AutoGroups == nil {
req.AutoGroups = []string{} req.AutoGroups = []string{}
} }
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups) req.AutoGroups)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, "account not found", http.StatusNotFound)
return
}
http.Error(w, "failed adding setup key", http.StatusInternalServerError)
return return
} }
@@ -81,29 +72,23 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID // GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["id"] keyID := vars["id"]
if len(keyID) == 0 { if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return return
} }
key, err := h.accountManager.GetSetupKey(account.Id, keyID) key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
@@ -112,34 +97,34 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey // UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["id"] keyID := vars["id"]
if len(keyID) == 0 { if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return return
} }
req := &api.PutApiSetupKeysIdJSONRequestBody{} req := &api.PutApiSetupKeysIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return return
} }
if req.AutoGroups == nil { if req.AutoGroups == nil {
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return return
} }
@@ -150,16 +135,8 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey) newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
if err != nil { if err != nil {
if e, ok := status.FromError(err); ok { util.WriteError(err, w)
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
default:
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
}
}
return return
} }
writeSuccess(w, newKey) writeSuccess(w, newKey)
@@ -168,25 +145,25 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey // GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
setupKeys, err := h.accountManager.ListSetupKeys(account.Id) setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
apiSetupKeys := make([]*api.SetupKey, 0) apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys { for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
} }
writeJSONObject(w, apiSetupKeys) util.WriteJSONObject(w, apiSetupKeys)
} }
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
@@ -194,7 +171,7 @@ func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(toResponseBody(key)) err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil { if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
} }

View File

@@ -6,9 +6,8 @@ import (
"fmt" "fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -28,20 +27,24 @@ const (
notFoundSetupKeyID = "notFoundSetupKeyID" notFoundSetupKeyID = "notFoundSetupKeyID"
) )
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys { func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys {
return &SetupKeys{ return &SetupKeys{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: testAccountID, Id: testAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{ SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey, defaultKey.Key: defaultKey,
}, },
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"}},
}, nil }, user, nil
}, },
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) { CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type { if keyName == newKey.Name || typ != newKey.Type {
@@ -49,14 +52,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
} }
return nil, fmt.Errorf("failed creating setup key") return nil, fmt.Errorf("failed creating setup key")
}, },
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) { GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID { switch keyID {
case defaultKey.Id: case defaultKey.Id:
return defaultKey, nil return defaultKey, nil
case newKey.Id: case newKey.Id:
return newKey, nil return newKey, nil
default: default:
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID) return nil, status.Errorf(status.NotFound, "key %s not found", keyID)
} }
}, },
@@ -64,10 +67,10 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
if key.Id == updatedSetupKey.Id { if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil return updatedSetupKey, nil
} }
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id) return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
}, },
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) { ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil return []*server.SetupKey{defaultKey}, nil
}, },
}, },
@@ -75,7 +78,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims { ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: user.Id,
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: testAccountID, AccountId: testAccountID,
} }
@@ -88,6 +91,8 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey() defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}) newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
@@ -153,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) {
}, },
} }
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey) handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@@ -2,12 +2,10 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/util"
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/status"
"net/http" "net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@@ -31,33 +29,34 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
// UpdateUser is a PUT requests to update User data // UpdateUser is a PUT requests to update User data
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut { if r.Method != http.MethodPut {
http.Error(w, "", http.StatusBadRequest) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
} }
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["id"] userID := vars["id"]
if len(userID) == 0 { if len(userID) == 0 {
http.Error(w, "invalid user ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
req := &api.PutApiUsersIdJSONRequestBody{} req := &api.PutApiUsersIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
userRole := server.StrRoleToUserRole(req.Role) userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown { if userRole == server.UserRoleUnknown {
http.Error(w, "invalid user role", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
return return
} }
@@ -67,40 +66,36 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
}) })
if err != nil { if err != nil {
if e, ok := status.FromError(err); ok { util.WriteError(err, w)
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
default:
http.Error(w, "failed to update user", http.StatusInternalServerError)
}
}
return return
} }
writeJSONObject(w, toUserResponse(newUser)) util.WriteJSONObject(w, toUserResponse(newUser))
} }
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite). // CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "", http.StatusNotFound) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
} }
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
return
} }
req := &api.PostApiUsersJSONRequestBody{} req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return return
} }
@@ -111,36 +106,30 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
}) })
if err != nil { if err != nil {
if e, ok := server.FromError(err); ok { util.WriteError(err, w)
switch e.Type() {
case server.UserAlreadyExists:
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
return
default:
}
}
http.Error(w, "failed to invite", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, toUserResponse(newUser)) util.WriteJSONObject(w, toUserResponse(newUser))
} }
// GetUsers returns a list of users of the account this user belongs to. // GetUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager. // It also gathers additional user data (like email and name) from the IDP manager.
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
http.Error(w, "", http.StatusBadRequest) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
} }
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
return
} }
data, err := h.accountManager.GetUsersFromAccount(account.Id) data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
@@ -149,7 +138,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
users = append(users, toUserResponse(r)) users = append(users, toUserResponse(r))
} }
writeJSONObject(w, users) util.WriteJSONObject(w, users)
} }
func toUserResponse(user *server.UserInfo) *api.User { func toUserResponse(user *server.UserInfo) *api.User {

View File

@@ -16,7 +16,7 @@ import (
func initUsers(user ...*server.User) *UserHandler { func initUsers(user ...*server.User) *UserHandler {
return &UserHandler{ return &UserHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
users := make(map[string]*server.User, 0) users := make(map[string]*server.User, 0)
for _, u := range user { for _, u := range user {
users[u.Id] = u users[u.Id] = u
@@ -25,9 +25,9 @@ func initUsers(user ...*server.User) *UserHandler {
Id: "12345", Id: "12345",
Domain: "netbird.io", Domain: "netbird.io",
Users: users, Users: users,
}, nil }, users[claims.UserId], nil
}, },
GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)
for _, v := range user { for _, v := range user {
users = append(users, &server.UserInfo{ users = append(users, &server.UserInfo{
@@ -44,7 +44,7 @@ func initUsers(user ...*server.User) *UserHandler {
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "1",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_id", AccountId: "test_id",
} }
@@ -66,7 +66,6 @@ func TestGetUsers(t *testing.T) {
expectedResult []*server.User expectedResult []*server.User
}{ }{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users}, {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
{name: "WrongRequestMethod", requestType: http.MethodPost, requestPath: "/api/users/", expectedStatus: http.StatusBadRequest},
} }
for _, tc := range tt { for _, tc := range tt {

View File

@@ -1,91 +0,0 @@
package http
import (
"encoding/json"
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"time"
)
// writeJSONObject simply writes object to the HTTP reponse in JSON format
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(obj)
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
return
}
}
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
d.Duration = time.Duration(value)
return nil
case string:
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}
func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, error) {
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(jwtClaims)
if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
return account, nil
}
func toHTTPError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String())
log.Error(unhandledMSG)
http.Error(w, unhandledMSG, http.StatusInternalServerError)
}

View File

@@ -0,0 +1,105 @@
package util
import (
"encoding/json"
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http"
"time"
)
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(obj)
if err != nil {
WriteError(err, w)
return
}
}
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
// MarshalJSON marshals the duration
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON unmarshals the duration
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
d.Duration = time.Duration(value)
return nil
case string:
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}
// WriteErrorResponse prepares and writes an error response i nJSON
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
type errorResponse struct {
Message string `json:"message"`
Code int `json:"code"`
}
w.WriteHeader(httpStatus)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(&errorResponse{
Message: errMsg,
Code: httpStatus,
})
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
}
}
// WriteError converts an error to an JSON error response.
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
func WriteError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
httpStatus := http.StatusInternalServerError
msg := "internal server error"
if ok {
switch errStatus.Type() {
case status.UserAlreadyExists:
httpStatus = http.StatusConflict
case status.AlreadyExists:
httpStatus = http.StatusConflict
case status.PreconditionFailed:
httpStatus = http.StatusPreconditionFailed
case status.PermissionDenied:
httpStatus = http.StatusForbidden
case status.NotFound:
httpStatus = http.StatusNotFound
case status.Internal:
httpStatus = http.StatusInternalServerError
case status.InvalidArgument:
httpStatus = http.StatusUnprocessableEntity
default:
}
msg = err.Error()
} else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.Error(unhandledMSG)
}
WriteErrorResponse(msg, httpStatus, w)
}

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@@ -24,6 +25,7 @@ type Auth0Manager struct {
httpClient ManagerHTTPClient httpClient ManagerHTTPClient
credentials ManagerCredentials credentials ManagerCredentials
helper ManagerHelper helper ManagerHelper
appMetrics telemetry.AppMetrics
} }
// Auth0ClientConfig auth0 manager client configurations // Auth0ClientConfig auth0 manager client configurations
@@ -51,6 +53,7 @@ type Auth0Credentials struct {
httpClient ManagerHTTPClient httpClient ManagerHTTPClient
jwtToken JWTToken jwtToken JWTToken
mux sync.Mutex mux sync.Mutex
appMetrics telemetry.AppMetrics
} }
// createUserRequest is a user create request // createUserRequest is a user create request
@@ -106,7 +109,7 @@ type auth0Profile struct {
} }
// NewAuth0Manager creates a new instance of the Auth0Manager // NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) { func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5 httpTransport.MaxIdleConns = 5
@@ -134,12 +137,15 @@ func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) {
clientConfig: config, clientConfig: config,
httpClient: httpClient, httpClient: httpClient,
helper: helper, helper: helper,
appMetrics: appMetrics,
} }
return &Auth0Manager{ return &Auth0Manager{
authIssuer: config.AuthIssuer, authIssuer: config.AuthIssuer,
credentials: credentials, credentials: credentials,
httpClient: httpClient, httpClient: httpClient,
helper: helper, helper: helper,
appMetrics: appMetrics,
}, nil }, nil
} }
@@ -170,6 +176,9 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
res, err = c.httpClient.Do(req) res, err = c.httpClient.Do(req)
if err != nil { if err != nil {
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountRequestError()
}
return res, err return res, err
} }
@@ -214,6 +223,10 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountAuthenticate()
}
// If jwtToken has an expires time and we have enough time to do a request return immediately // If jwtToken has an expires time and we have enough time to do a request return immediately
if c.jwtStillValid() { if c.jwtStillValid() {
return c.jwtToken, nil return c.jwtToken, nil
@@ -287,9 +300,16 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -342,9 +362,16 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -398,9 +425,16 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err return err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
defer func() { defer func() {
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
@@ -416,12 +450,13 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
} }
func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) { func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) {
invite := true
req := &createUserRequest{ req := &createUserRequest{
Email: email, Email: email,
Name: name, Name: name,
AppMeta: AppMetadata{ AppMeta: AppMetadata{
WTAccountID: accountID, WTAccountID: accountID,
WTPendingInvite: true, WTPendingInvite: &invite,
}, },
Connection: "Username-Password-Authentication", Connection: "Username-Password-Authentication",
Password: GeneratePassword(8, 1, 1, 1), Password: GeneratePassword(8, 1, 1, 1),
@@ -502,6 +537,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
jobResp, err := am.httpClient.Do(exportJobReq) jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil { if err != nil {
log.Debugf("Couldn't get job response %v", err) log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
@@ -512,6 +550,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
} }
}() }()
if jobResp.StatusCode != 200 { if jobResp.StatusCode != 200 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode) return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
} }
@@ -530,6 +571,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
} }
if exportJobResp.ID == "" { if exportJobResp.ID == "" {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
} }
@@ -556,12 +600,16 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + email reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken) body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
userResp := []*UserData{} userResp := []*UserData{}
err = am.helper.Unmarshal(body, &userResp) err = am.helper.Unmarshal(body, &userResp)
@@ -585,9 +633,16 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := am.httpClient.Do(req) resp, err := am.httpClient.Do(req)
if err != nil { if err != nil {
log.Debugf("Couldn't get job response %v", err) log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
@@ -598,6 +653,9 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
} }
}() }()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) { if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
} }
@@ -698,7 +756,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
Email: profile.Email, Email: profile.Email,
AppMetadata: AppMetadata{ AppMetadata: AppMetadata{
WTAccountID: profile.AccountID, WTAccountID: profile.AccountID,
WTPendingInvite: profile.PendingInvite, WTPendingInvite: &profile.PendingInvite,
}, },
}) })
} }
@@ -729,13 +787,12 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
log.Errorf("error while closing body for url %s: %v", url, err) log.Errorf("error while closing body for url %s: %v", url, err)
} }
}() }()
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
return body, nil return body, nil
} }

View File

@@ -3,6 +3,7 @@ package idp
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io" "io"
"net/http" "net/http"
@@ -340,7 +341,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code", name: "Bad Status Code",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID), expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
appMetadata: appMetadata, appMetadata: appMetadata,
statusCode: 400, statusCode: 400,
helper: JsonParser{}, helper: JsonParser{},
@@ -363,7 +364,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request", name: "Good request",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID), expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
appMetadata: appMetadata, appMetadata: appMetadata,
statusCode: 200, statusCode: 200,
helper: JsonParser{}, helper: JsonParser{},
@@ -371,7 +372,23 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
assertErrFuncMessage: "shouldn't return error", assertErrFuncMessage: "shouldn't return error",
} }
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} { invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{ jwtReqClient := mockHTTPClient{
resBody: testCase.inputReqBody, resBody: testCase.inputReqBody,
@@ -459,7 +476,7 @@ func TestNewAuth0Manager(t *testing.T) {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuth0Manager(testCase.inputConfig) _, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
}) })
} }

View File

@@ -2,6 +2,7 @@ package idp
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@@ -51,7 +52,7 @@ type AppMetadata struct {
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP // WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
// maps to wt_account_id when json.marshal // maps to wt_account_id when json.marshal
WTAccountID string `json:"wt_account_id,omitempty"` WTAccountID string `json:"wt_account_id,omitempty"`
WTPendingInvite bool `json:"wt_pending_invite"` WTPendingInvite *bool `json:"wt_pending_invite"`
} }
// JWTToken a JWT object that holds information of a token // JWTToken a JWT object that holds information of a token
@@ -64,12 +65,12 @@ type JWTToken struct {
} }
// NewManager returns a new idp manager based on the configuration that it receives // NewManager returns a new idp manager based on the configuration that it receives
func NewManager(config Config) (Manager, error) { func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
switch strings.ToLower(config.ManagerType) { switch strings.ToLower(config.ManagerType) {
case "none", "": case "none", "":
return nil, nil return nil, nil
case "auth0": case "auth0":
return NewAuth0Manager(config.Auth0ClientCredentials) return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
default: default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
} }

View File

@@ -398,17 +398,17 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro
return nil, err return nil, err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := NewStore(config.Datadir) store, err := NewFileStore(config.Datadir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peersUpdateManager := NewPeersUpdateManager() peersUpdateManager := NewPeersUpdateManager()
accountManager, err := BuildManager(store, peersUpdateManager, nil) accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -488,17 +488,17 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer() s := grpc.NewServer()
store, err := server.NewStore(config.Datadir) store, err := server.NewFileStore(config.Datadir)
if err != nil { if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
} }
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil) accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer) mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() { go func() {

View File

@@ -5,11 +5,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"io" "io"
"net/http" "net/http"
"regexp"
"sort"
"strings" "strings"
"time" "time"
) )
@@ -17,7 +20,7 @@ import (
const ( const (
// PayloadEvent identifies an event type // PayloadEvent identifies an event type
PayloadEvent = "self-hosted stats" PayloadEvent = "self-hosted stats"
// payloadEndpoint metrics endpoint to send anonymous data // payloadEndpoint metrics defaultEndpoint to send anonymous data
payloadEndpoint = "https://metrics.netbird.io" payloadEndpoint = "https://metrics.netbird.io"
// defaultPushInterval default interval to push metrics // defaultPushInterval default interval to push metrics
defaultPushInterval = 24 * time.Hour defaultPushInterval = 24 * time.Hour
@@ -85,6 +88,7 @@ func (w *Worker) Run() {
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
w.lastRun = time.Now()
} }
} }
} }
@@ -161,11 +165,15 @@ func (w *Worker) generateProperties() properties {
groups int groups int
routes int routes int
nameservers int nameservers int
uiClient int
version string version string
peerActiveVersions []string
osUIClients map[string]int
) )
start := time.Now() start := time.Now()
metricsProperties := make(properties) metricsProperties := make(properties)
osPeers = make(map[string]int) osPeers = make(map[string]int)
osUIClients = make(map[string]int)
uptime = time.Since(w.startupTime).Seconds() uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers() connections := w.connManager.GetAllConnectedPeers()
version = system.NetbirdVersion() version = system.NetbirdVersion()
@@ -184,20 +192,33 @@ func (w *Worker) generateProperties() properties {
for _, peer := range account.Peers { for _, peer := range account.Peers {
peers++ peers++
if peer.SetupKey != "" { if peer.SetupKey == "" {
userPeers++ userPeers++
} }
osKey := strings.ToLower(fmt.Sprintf("peer_os_%s", peer.Meta.GoOS))
osCount := osPeers[osKey]
osPeers[osKey] = osCount + 1
if peer.Meta.UIVersion != "" {
uiClient++
uiOSKey := strings.ToLower(fmt.Sprintf("ui_client_os_%s", peer.Meta.GoOS))
osUICount := osUIClients[uiOSKey]
osUIClients[uiOSKey] = osUICount + 1
}
_, connected := connections[peer.Key] _, connected := connections[peer.Key]
if connected || peer.Status.LastSeen.After(w.lastRun) { if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++ activePeersLastDay++
osActiveKey := osKey + "_active"
osActiveCount := osPeers[osActiveKey]
osPeers[osActiveKey] = osActiveCount + 1
peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion)
} }
osKey := strings.ToLower(fmt.Sprintf("peer_os_%s", peer.Meta.GoOS))
osCount := osPeers[osKey]
osPeers[osKey] = osCount + 1
} }
} }
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
metricsProperties["uptime"] = uptime metricsProperties["uptime"] = uptime
metricsProperties["accounts"] = accounts metricsProperties["accounts"] = accounts
metricsProperties["users"] = users metricsProperties["users"] = users
@@ -210,11 +231,17 @@ func (w *Worker) generateProperties() properties {
metricsProperties["routes"] = routes metricsProperties["routes"] = routes
metricsProperties["nameservers"] = nameservers metricsProperties["nameservers"] = nameservers
metricsProperties["version"] = version metricsProperties["version"] = version
metricsProperties["min_active_peer_version"] = minActivePeerVersion
metricsProperties["max_active_peer_version"] = maxActivePeerVersion
metricsProperties["ui_clients"] = uiClient
for os, count := range osPeers { for os, count := range osPeers {
metricsProperties[os] = count metricsProperties[os] = count
} }
for os, count := range osUIClients {
metricsProperties[os] = count
}
metricsProperties["metric_generation_time"] = time.Since(start).Milliseconds() metricsProperties["metric_generation_time"] = time.Since(start).Milliseconds()
return metricsProperties return metricsProperties
@@ -279,5 +306,32 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
req.Header.Add("content-type", "application/json") req.Header.Add("content-type", "application/json")
return req, nil return req, nil
}
func getMinMaxVersion(inputList []string) (string, string) {
reg, err := regexp.Compile(version.SemverRegexpRaw)
if err != nil {
return "", ""
}
versions := make([]*version.Version, 0)
for _, raw := range inputList {
if raw != "" && reg.MatchString(raw) {
v, err := version.NewVersion(raw)
if err == nil {
versions = append(versions, v)
}
}
}
switch len(versions) {
case 0:
return "", ""
case 1:
v := versions[0].String()
return v, v
default:
sort.Sort(version.Collection(versions))
return versions[0].String(), versions[len(versions)-1].String()
}
} }

View File

@@ -1,13 +0,0 @@
## Migration from Store v2 to Store v2
Previously Account.Id was an Auth0 user id.
Conversion moves user id to Account.CreatedBy and generates a new Account.Id using xid.
It also adds a User with id = old Account.Id with a role Admin.
To start a conversion simply run the command below providing your current Wiretrustee Management datadir (where store.json file is located)
and a new data directory location (where a converted store.js will be stored):
```shell
./migration --oldDir /var/wiretrustee/datadir --newDir /var/wiretrustee/newdatadir/
```
Afterwards you can run the Management service providing ```/var/wiretrustee/newdatadir/ ``` as a datadir.

View File

@@ -1,56 +0,0 @@
package main
import (
"flag"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/rs/xid"
)
func main() {
oldDir := flag.String("oldDir", "old store directory", "/var/wiretrustee/datadir")
newDir := flag.String("newDir", "new store directory", "/var/wiretrustee/newdatadir")
flag.Parse()
oldStore, err := server.NewStore(*oldDir)
if err != nil {
panic(err)
}
newStore, err := server.NewStore(*newDir)
if err != nil {
panic(err)
}
err = Convert(oldStore, newStore)
if err != nil {
panic(err)
}
fmt.Println("successfully converted")
}
// Convert converts old store ato a new store
// Previously Account.Id was an Auth0 user id
// Conversion moved user id to Account.CreatedBy and generated a new Account.Id using xid
// It also adds a User with id = old Account.Id with a role Admin
func Convert(oldStore *server.FileStore, newStore *server.FileStore) error {
for _, account := range oldStore.Accounts {
accountCopy := account.Copy()
accountCopy.Id = xid.New().String()
accountCopy.CreatedBy = account.Id
accountCopy.Users[account.Id] = &server.User{
Id: account.Id,
Role: server.UserRoleAdmin,
}
err := newStore.SaveAccount(accountCopy)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,76 +0,0 @@
package main
import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
"path/filepath"
"testing"
)
func TestConvertAccounts(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("../testdata/storev1.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := server.NewStore(storeDir)
if err != nil {
t.Fatal(err)
}
convertedStore, err := server.NewStore(filepath.Join(storeDir, "converted"))
if err != nil {
t.Fatal(err)
}
err = Convert(store, convertedStore)
if err != nil {
t.Fatal(err)
}
if len(store.Accounts) != len(convertedStore.Accounts) {
t.Errorf("expecting the same number of accounts after conversion")
}
for _, account := range store.Accounts {
convertedAccount, err := convertedStore.GetUserAccount(account.Id)
if err != nil || convertedAccount == nil {
t.Errorf("expecting Account %s to be converted", account.Id)
return
}
if convertedAccount.CreatedBy != account.Id {
t.Errorf("expecting converted Account.CreatedBy field to be equal to the old Account.Id")
return
}
if convertedAccount.Id == account.Id {
t.Errorf("expecting converted Account.Id to be different from Account.Id")
return
}
if len(convertedAccount.Users) != 1 {
t.Errorf("expecting converted Account.Users to be of size 1")
return
}
user := convertedAccount.Users[account.Id]
if user == nil {
t.Errorf("expecting to find a user in converted Account.Users")
return
}
if user.Role != server.UserRoleAdmin {
t.Errorf("expecting to find a user in converted Account.Users with a role Admin")
return
}
for peerId := range account.Peers {
convertedPeer := convertedAccount.Peers[peerId]
if convertedPeer == nil {
t.Errorf("expecting Account Peer of StoreV1 to be found in StoreV2")
return
}
}
}
}

View File

@@ -14,14 +14,13 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error) GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error) AccountExistsFunc func(accountId string) (*bool, error)
GetPeerFunc func(peerKey string) (*server.Peer, error) GetPeerFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error MarkPeerConnectedFunc func(peerKey string, connected bool) error
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
@@ -35,42 +34,51 @@ type MockAccountManager struct {
GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error) GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
SaveRuleFunc func(accountID string, rule *server.Rule) error SaveRuleFunc func(accountID string, rule *server.Rule) error
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
DeleteRuleFunc func(accountID, ruleID string) error DeleteRuleFunc func(accountID, ruleID string) error
ListRulesFunc func(accountID string) ([]*server.Rule, error) ListRulesFunc func(accountID, userID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error) UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
GetRouteFunc func(accountID, routeID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, route *route.Route) error SaveRouteFunc func(accountID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID string) ([]*route.Route, error) ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error) CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
GetDNSDomainFunc func() string
} }
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.UserInfo, error) { func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil { if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(accountID) return am.GetUsersFromAccountFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
} }
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountId, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser( func (am *MockAccountManager) GetOrCreateAccountByUser(
userId, domain string, userId, domain string,
@@ -106,16 +114,8 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface // GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) { func (am *MockAccountManager) GetAccountByUserOrAccountID(
if am.GetAccountByIdFunc != nil {
return am.GetAccountByIdFunc(accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById is not implemented")
}
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string, userId, accountId, domain string,
) (*server.Account, error) { ) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil { if am.GetAccountByUserOrAccountIdFunc != nil {
@@ -123,7 +123,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
} }
return nil, status.Errorf( return nil, status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountByUserOrAccountId is not implemented", "method GetAccountByUserOrAccountID is not implemented",
) )
} }
@@ -151,26 +151,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
// RenamePeer mock implementation of RenamePeer from server.AccountManager interface
func (am *MockAccountManager) RenamePeer(
accountId string,
peerKey string,
newName string,
) (*server.Peer, error) {
if am.RenamePeerFunc != nil {
return am.RenamePeerFunc(accountId, peerKey, newName)
}
return nil, status.Errorf(codes.Unimplemented, "method RenamePeer is not implemented")
}
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountId, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface // GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) { func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
if am.GetPeerByIPFunc != nil { if am.GetPeerByIPFunc != nil {
@@ -272,9 +252,9 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
} }
// GetRule mock implementation of GetRule from server.AccountManager interface // GetRule mock implementation of GetRule from server.AccountManager interface
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) { func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
if am.GetRuleFunc != nil { if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID) return am.GetRuleFunc(accountID, ruleID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented")
} }
@@ -304,9 +284,9 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
} }
// ListRules mock implementation of ListRules from server.AccountManager interface // ListRules mock implementation of ListRules from server.AccountManager interface
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) { func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) {
if am.ListRulesFunc != nil { if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID) return am.ListRulesFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented")
} }
@@ -345,16 +325,16 @@ func (am *MockAccountManager) UpdatePeer(accountID string, peer *server.Peer) (*
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
if am.GetRouteFunc != nil { if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled) return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
} }
// GetRoute mock implementation of GetRoute from server.AccountManager interface // GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) { func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil { if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID) return am.GetRouteFunc(accountID, routeID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented")
} }
@@ -384,9 +364,9 @@ func (am *MockAccountManager) DeleteRoute(accountID, routeID string) error {
} }
// ListRoutes mock implementation of ListRoutes from server.AccountManager interface // ListRoutes mock implementation of ListRoutes from server.AccountManager interface
func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, error) { func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
if am.ListRoutesFunc != nil { if am.ListRoutesFunc != nil {
return am.ListRoutesFunc(accountID) return am.ListRoutesFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
} }
@@ -401,18 +381,18 @@ func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKe
} }
// GetSetupKey mocks GetSetupKey of the AccountManager interface // GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) { func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil { if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, keyID) return am.GetSetupKeyFunc(accountID, userID, keyID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
} }
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface // ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) { func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil { if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID) return am.ListSetupKeysFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
@@ -435,9 +415,9 @@ func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*
} }
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface // CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil { if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, enabled) return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled)
} }
return nil, nil return nil, nil
} }
@@ -483,9 +463,26 @@ func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserIn
} }
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface // GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
error) {
if am.GetAccountFromTokenFunc != nil { if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(claims) return am.GetAccountFromTokenFunc(claims)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer, error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetPeersFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeersFunc is not implemented")
}
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface
func (am *MockAccountManager) GetDNSDomain() string {
if am.GetDNSDomainFunc != nil {
return am.GetDNSDomainFunc()
}
return ""
} }

View File

@@ -1,10 +1,11 @@
package server package server
import ( import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid" "github.com/rs/xid"
"google.golang.org/grpc/codes" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/status"
"strconv" "strconv"
"unicode/utf8" "unicode/utf8"
) )
@@ -20,6 +21,10 @@ const (
UpdateNameServerGroupGroups UpdateNameServerGroupGroups
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation // UpdateNameServerGroupEnabled indicates a nameserver group status update operation
UpdateNameServerGroupEnabled UpdateNameServerGroupEnabled
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
UpdateNameServerGroupPrimary
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
UpdateNameServerGroupDomains
) )
// NameServerGroupUpdateOperationType operation type // NameServerGroupUpdateOperationType operation type
@@ -37,6 +42,10 @@ func (t NameServerGroupUpdateOperationType) String() string {
return "UpdateNameServerGroupGroups" return "UpdateNameServerGroupGroups"
case UpdateNameServerGroupEnabled: case UpdateNameServerGroupEnabled:
return "UpdateNameServerGroupEnabled" return "UpdateNameServerGroupEnabled"
case UpdateNameServerGroupPrimary:
return "UpdateNameServerGroupPrimary"
case UpdateNameServerGroupDomains:
return "UpdateNameServerGroupDomains"
default: default:
return "InvalidOperation" return "InvalidOperation"
} }
@@ -50,12 +59,13 @@ type NameServerGroupUpdateOperation struct {
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
nsGroup, found := account.NameServerGroups[nsGroupID] nsGroup, found := account.NameServerGroups[nsGroupID]
@@ -63,17 +73,18 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
return nsGroup.Copy(), nil return nsGroup.Copy(), nil
} }
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
newNSGroup := &nbdns.NameServerGroup{ newNSGroup := &nbdns.NameServerGroup{
@@ -83,6 +94,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
NameServers: nameServerList, NameServers: nameServerList,
Groups: groups, Groups: groups,
Enabled: enabled, Enabled: enabled,
Primary: primary,
Domains: domains,
} }
err = validateNameServerGroup(false, newNSGroup, account) err = validateNameServerGroup(false, newNSGroup, account)
@@ -102,21 +115,28 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
return nil, err return nil, err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
}
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
} }
// SaveNameServerGroup saves nameserver group // SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error { func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if nsGroupToSave == nil { if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil") return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
} }
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
err = validateNameServerGroup(true, nsGroupToSave, account) err = validateNameServerGroup(true, nsGroupToSave, account)
@@ -132,26 +152,33 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
return err return err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
}
return nil return nil
} }
// UpdateNameServerGroup updates existing nameserver group with set of operations // UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
if len(operations) == 0 { if len(operations) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty") return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
} }
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID] nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
} }
newNSGroup := nsGroupToUpdate.Copy() newNSGroup := nsGroupToUpdate.Copy()
@@ -159,12 +186,12 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
for _, operation := range operations { for _, operation := range operations {
valuesCount := len(operation.Values) valuesCount := len(operation.Values)
if valuesCount < 1 { if valuesCount < 1 {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String()) return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
} }
for _, value := range operation.Values { for _, value := range operation.Values {
if value == "" { if value == "" {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String()) return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
} }
} }
switch operation.Type { switch operation.Type {
@@ -172,7 +199,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
newNSGroup.Description = operation.Values[0] newNSGroup.Description = operation.Values[0]
case UpdateNameServerGroupName: case UpdateNameServerGroupName:
if valuesCount > 1 { if valuesCount > 1 {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount) return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
} }
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups) err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
if err != nil { if err != nil {
@@ -202,9 +229,21 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
case UpdateNameServerGroupEnabled: case UpdateNameServerGroupEnabled:
enabled, err := strconv.ParseBool(operation.Values[0]) enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
} }
newNSGroup.Enabled = enabled newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
}
newNSGroup.Primary = primary
case UpdateNameServerGroupDomains:
err = validateDomainInput(false, operation.Values)
if err != nil {
return nil, err
}
newNSGroup.Domains = operation.Values
} }
} }
@@ -216,17 +255,24 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
return nil, err return nil, err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
}
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
} }
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
delete(account.NameServerGroups, nsGroupID) delete(account.NameServerGroups, nsGroupID)
@@ -237,17 +283,24 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
return err return err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
}
return nil return nil
} }
// ListNameServerGroups returns a list of nameserver groups from account // ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
@@ -264,11 +317,16 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
nsGroupID = nameserverGroup.ID nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID] _, found := account.NameServerGroups[nsGroupID]
if !found { if !found {
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
} }
} }
err := validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil { if err != nil {
return err return err
} }
@@ -286,14 +344,32 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
return nil return nil
} }
func validateDomainInput(primary bool, domains []string) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain")
}
if primary && len(domains) != 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
" you should set either primary or domain")
}
for _, domain := range domains {
_, valid := dns.IsDomainName(domain)
if !valid {
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
}
}
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
} }
for _, nsGroup := range nsGroupMap { for _, nsGroup := range nsGroupMap {
if name == nsGroup.Name && nsGroup.ID != nsGroupID { if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name) return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
} }
} }
@@ -303,19 +379,19 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
func validateNSList(list []nbdns.NameServer) error { func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list) nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 2 { if nsListLenght == 0 || nsListLenght > 2 {
return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list)) return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
} }
return nil return nil
} }
func validateGroups(list []string, groups map[string]*Group) error { func validateGroups(list []string, groups map[string]*Group) error {
if len(list) == 0 { if len(list) == 0 {
return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty") return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
} }
for _, id := range list { for _, id := range list {
if id == "" { if id == "" {
return status.Errorf(codes.InvalidArgument, "group ID should not be empty string") return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
} }
found := false found := false
for groupID := range groups { for groupID := range groups {
@@ -325,7 +401,7 @@ func validateGroups(list []string, groups map[string]*Group) error {
} }
} }
if !found { if !found {
return status.Errorf(codes.InvalidArgument, "group id %s not found", id) return status.Errorf(status.InvalidArgument, "group id %s not found", id)
} }
} }

View File

@@ -14,6 +14,8 @@ const (
existingNSGroupID = "existingNSGroup" existingNSGroupID = "existingNSGroup"
nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI=" nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI="
validDomain = "example.com"
invalidDomain = "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com"
) )
func TestCreateNameServerGroup(t *testing.T) { func TestCreateNameServerGroup(t *testing.T) {
@@ -23,6 +25,8 @@ func TestCreateNameServerGroup(t *testing.T) {
enabled bool enabled bool
groups []string groups []string
nameServers []nbdns.NameServer nameServers []nbdns.NameServer
primary bool
domains []string
} }
testCases := []struct { testCases := []struct {
@@ -33,11 +37,12 @@ func TestCreateNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup expectedNSGroup *nbdns.NameServerGroup
}{ }{
{ {
name: "Create A NS Group", name: "Create A NS Group With Primary Status",
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
groups: []string{group1ID}, groups: []string{group1ID},
primary: true,
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -57,6 +62,52 @@ func TestCreateNameServerGroup(t *testing.T) {
expectedNSGroup: &nbdns.NameServerGroup{ expectedNSGroup: &nbdns.NameServerGroup{
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
Groups: []string{group1ID},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Enabled: true,
},
},
{
name: "Create A NS Group With Domains",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
primary: false,
domains: []string{validDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
Name: "super",
Description: "super",
Primary: false,
Domains: []string{"example.com"},
Groups: []string{group1ID}, Groups: []string{group1ID},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
@@ -78,6 +129,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: existingNSGroupName, name: existingNSGroupName,
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -101,6 +153,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "", name: "",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -124,6 +177,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "1234567890123456789012345678901234567890extra", name: "1234567890123456789012345678901234567890extra",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -147,6 +201,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{}, nameServers: []nbdns.NameServer{},
enabled: true, enabled: true,
@@ -159,6 +214,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -187,6 +243,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{}, groups: []string{},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -210,6 +267,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{"missingGroup"}, groups: []string{"missingGroup"},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -233,6 +291,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{""}, groups: []string{""},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -251,6 +310,53 @@ func TestCreateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{
name: "Should Not Create If No Domain Or Primary",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Domain List Is Invalid",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
domains: []string{invalidDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
@@ -270,6 +376,8 @@ func TestCreateNameServerGroup(t *testing.T) {
testCase.inputArgs.description, testCase.inputArgs.description,
testCase.inputArgs.nameServers, testCase.inputArgs.nameServers,
testCase.inputArgs.groups, testCase.inputArgs.groups,
testCase.inputArgs.primary,
testCase.inputArgs.domains,
testCase.inputArgs.enabled, testCase.inputArgs.enabled,
) )
@@ -295,6 +403,7 @@ func TestSaveNameServerGroup(t *testing.T) {
ID: "testingNSGroup", ID: "testingNSGroup",
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -313,6 +422,10 @@ func TestSaveNameServerGroup(t *testing.T) {
validGroups := []string{group2ID} validGroups := []string{group2ID}
invalidGroups := []string{"nonExisting"} invalidGroups := []string{"nonExisting"}
disabledPrimary := false
validDomains := []string{validDomain}
invalidDomains := []string{invalidDomain}
validNameServerList := []nbdns.NameServer{ validNameServerList := []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -348,6 +461,8 @@ func TestSaveNameServerGroup(t *testing.T) {
existingNSGroup *nbdns.NameServerGroup existingNSGroup *nbdns.NameServerGroup
newID *string newID *string
newName *string newName *string
newPrimary *bool
newDomains []string
newNSList []nbdns.NameServer newNSList []nbdns.NameServer
newGroups []string newGroups []string
skipCopying bool skipCopying bool
@@ -356,16 +471,20 @@ func TestSaveNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup expectedNSGroup *nbdns.NameServerGroup
}{ }{
{ {
name: "Should Update Name Server Group", name: "Should Config Name Server Group",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &validName, newName: &validName,
newGroups: validGroups, newGroups: validGroups,
newPrimary: &disabledPrimary,
newDomains: validDomains,
newNSList: validNameServerList, newNSList: validNameServerList,
errFunc: require.NoError, errFunc: require.NoError,
shouldCreate: true, shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{ expectedNSGroup: &nbdns.NameServerGroup{
ID: "testingNSGroup", ID: "testingNSGroup",
Name: validName, Name: validName,
Primary: false,
Domains: validDomains,
Description: "super", Description: "super",
NameServers: validNameServerList, NameServers: validNameServerList,
Groups: validGroups, Groups: validGroups,
@@ -373,68 +492,91 @@ func TestSaveNameServerGroup(t *testing.T) {
}, },
}, },
{ {
name: "Should Not Update If Name Is Small", name: "Should Not Config If Name Is Small",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &invalidNameSmall, newName: &invalidNameSmall,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Name Is Large", name: "Should Not Config If Name Is Large",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &invalidNameLarge, newName: &invalidNameLarge,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Name Exists", name: "Should Not Config If Name Exists",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &invalidNameExisting, newName: &invalidNameExisting,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If ID Don't Exist", name: "Should Not Config If ID Don't Exist",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newID: &invalidID, newID: &invalidID,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Nameserver List Is Small", name: "Should Not Config If Nameserver List Is Small",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newNSList: []nbdns.NameServer{}, newNSList: []nbdns.NameServer{},
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Nameserver List Is Large", name: "Should Not Config If Nameserver List Is Large",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newNSList: invalidNameServerListLarge, newNSList: invalidNameServerListLarge,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Groups List Is Empty", name: "Should Not Config If Groups List Is Empty",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newGroups: []string{}, newGroups: []string{},
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Groups List Has Empty ID", name: "Should Not Config If Groups List Has Empty ID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newGroups: []string{""}, newGroups: []string{""},
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Groups List Has Non Existing Group ID", name: "Should Not Config If Groups List Has Non Existing Group ID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newGroups: invalidGroups, newGroups: invalidGroups,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{
name: "Should Not Config If Domains List Is Empty",
existingNSGroup: existingNSGroup,
newPrimary: &disabledPrimary,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Config If Primary And Domains",
existingNSGroup: existingNSGroup,
newPrimary: &existingNSGroup.Primary,
newDomains: validDomains,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Config If Domains List Is Invalid",
existingNSGroup: existingNSGroup,
newPrimary: &disabledPrimary,
newDomains: invalidDomains,
errFunc: require.Error,
shouldCreate: false,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
@@ -475,6 +617,14 @@ func TestSaveNameServerGroup(t *testing.T) {
if testCase.newNSList != nil { if testCase.newNSList != nil {
nsGroupToSave.NameServers = testCase.newNSList nsGroupToSave.NameServers = testCase.newNSList
} }
if testCase.newPrimary != nil {
nsGroupToSave.Primary = *testCase.newPrimary
}
if testCase.newDomains != nil {
nsGroupToSave.Domains = testCase.newDomains
}
} }
err = am.SaveNameServerGroup(account.Id, nsGroupToSave) err = am.SaveNameServerGroup(account.Id, nsGroupToSave)
@@ -485,6 +635,11 @@ func TestSaveNameServerGroup(t *testing.T) {
return return
} }
account, err = am.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID] savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved) require.True(t, saved)
@@ -503,6 +658,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID, ID: nsGroupID,
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -529,7 +685,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup expectedNSGroup *nbdns.NameServerGroup
}{ }{
{ {
name: "Should Update Single Property", name: "Should Config Single Property",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -544,6 +700,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID, ID: nsGroupID,
Name: "superNew", Name: "superNew",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -561,7 +718,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
}, },
}, },
{ {
name: "Should Update Multiple Properties", name: "Should Config Multiple Properties",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -585,6 +742,14 @@ func TestUpdateNameServerGroup(t *testing.T) {
Type: UpdateNameServerGroupEnabled, Type: UpdateNameServerGroupEnabled,
Values: []string{"false"}, Values: []string{"false"},
}, },
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{validDomain},
},
}, },
errFunc: require.NoError, errFunc: require.NoError,
shouldCreate: true, shouldCreate: true,
@@ -592,6 +757,8 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID, ID: nsGroupID,
Name: "superNew", Name: "superNew",
Description: "superDescription", Description: "superDescription",
Primary: false,
Domains: []string{validDomain},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("127.0.0.1"), IP: netip.MustParseAddr("127.0.0.1"),
@@ -609,20 +776,20 @@ func TestUpdateNameServerGroup(t *testing.T) {
}, },
}, },
{ {
name: "Should Not Update On Invalid ID", name: "Should Not Config On Invalid ID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: "nonExistingNSGroup", nsGroupID: "nonExistingNSGroup",
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Empty Operations", name: "Should Not Config On Empty Operations",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{}, operations: []NameServerGroupUpdateOperation{},
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Empty Values", name: "Should Not Config On Empty Values",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -633,7 +800,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Empty String", name: "Should Not Config On Empty String",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -645,7 +812,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Name Large String", name: "Should Not Config On Invalid Name Large String",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -657,7 +824,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid On Existing Name", name: "Should Not Config On Invalid On Existing Name",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -669,7 +836,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid On Multiple Name Values", name: "Should Not Config On Invalid On Multiple Name Values",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -681,7 +848,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Boolean", name: "Should Not Config On Invalid Boolean",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -693,7 +860,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Nameservers Wrong Schema", name: "Should Not Config On Invalid Nameservers Wrong Schema",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -705,7 +872,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Nameservers Wrong IP", name: "Should Not Config On Invalid Nameservers Wrong IP",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -717,7 +884,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Large Number Of Nameservers", name: "Should Not Config On Large Number Of Nameservers",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -729,7 +896,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid GroupID", name: "Should Not Config On Invalid GroupID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -740,6 +907,30 @@ func TestUpdateNameServerGroup(t *testing.T) {
}, },
errFunc: require.Error, errFunc: require.Error,
}, },
{
name: "Should Not Config On Invalid Domains",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{invalidDomain},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Primary Status",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
@@ -865,12 +1056,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return BuildManager(store, NewPeersUpdateManager(), nil) return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
} }
func createNSStore(t *testing.T) (Store, error) { func createNSStore(t *testing.T) (Store, error) {
dataDir := t.TempDir() dataDir := t.TempDir()
store, err := NewStore(dataDir) store, err := NewFileStore(dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -2,10 +2,10 @@ package server
import ( import (
"github.com/c-robinson/iplib" "github.com/c-robinson/iplib"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"math/rand" "math/rand"
"net" "net"
"sync" "sync"
@@ -23,9 +23,10 @@ const (
) )
type NetworkMap struct { type NetworkMap struct {
Peers []*Peer Peers []*Peer
Network *Network Network *Network
Routes []*route.Route Routes []*route.Route
DNSConfig nbdns.Config
} }
type Network struct { type Network struct {
@@ -93,7 +94,7 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
ips, _ := generateIPs(&ipNet, takenIPMap) ips, _ := generateIPs(&ipNet, takenIPMap)
if len(ips) == 0 { if len(ips) == 0 {
return nil, status.Errorf(codes.OutOfRange, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
} }
// pick a random IP // pick a random IP

Some files were not shown because too many files have changed in this diff Show More