Compare commits

...

30 Commits

Author SHA1 Message Date
Maycon Santos
157137e4ad Use a single way to generate network map (#550) 2022-11-08 11:38:40 +01:00
Maycon Santos
7d7e576775 Set report caller when info or higher (#555) 2022-11-08 10:56:13 +01:00
Misha Bragin
f37b43a542 Save Peer Status separately in the FileStore (#554)
Due to peer reconnects when restarting the Management service,
there are lots of SaveStore operations to update peer status.

Store.SavePeerStatus stores peer status separately and the
FileStore implementation stores it in memory.
2022-11-08 10:46:12 +01:00
Maycon Santos
7e262572a4 Move dns label generation to store (#552) 2022-11-08 10:31:34 +01:00
Misha Bragin
a768a0aa8a Always lock the store when getting an account (#551) 2022-11-07 19:09:22 +01:00
Misha Bragin
ed7ac81027 Introduce locking on the account level (#548) 2022-11-07 17:52:23 +01:00
Maycon Santos
1f845f466c Add account copy test (#549) 2022-11-07 17:37:28 +01:00
Maycon Santos
270f0e4ce8 Feature/dns protocol (#543)
Added DNS update protocol message

Added sync to clients

Update nameserver API with new fields

Added default NS groups

Added new dns-name flag for the management service append to peer DNS label
2022-11-07 15:38:21 +01:00
Misha Bragin
d0c6d88971 Simplified Store Interface (#545)
This PR simplifies Store and FileStore
by keeping just the Get and Save account methods.

The AccountManager operates mostly around
a single account, so it makes sense to fetch
the whole account object from the store.
2022-11-07 12:10:56 +01:00
Misha Bragin
4321b71984 Hide content based on user role (#541) 2022-11-05 10:24:50 +01:00
Maycon Santos
e8d82c1bd3 Feature/dns-server (#537)
Adding DNS server for client

Updated the API with new fields

Added custom zone object for peer's DNS resolution
2022-11-03 18:39:37 +01:00
Misha Bragin
6aa7a2c5e1 Hide setup key from non-admin users (#539) 2022-11-03 17:02:31 +01:00
Rui Lopes
2e0bf61e9a correctly set the windows application icon on windows (#535)
the icon format is not really supported, so this uses a png instead.

this closes https://github.com/netbirdio/netbird/issues/534.
2022-11-01 00:34:30 +01:00
Maycon Santos
126af9dffc Return gateway address if not nil (#533)
If the gateway address would be nil which is
the case on macOS, we return the preferredSrc

added tests for getExistingRIBRouteGateway function

update log message
2022-10-31 11:54:34 +01:00
Maycon Santos
4cdf2df660 Update sign pipeline version to 0.0.4 (#531)
This version has a fix for the
macOS UI client architecture
2022-10-31 11:03:42 +01:00
Maycon Santos
9a4c9aa286 Add active peers count per OS (#526)
* Add active peers count per OS

* increase iface tests timeout
2022-10-26 14:48:40 +02:00
Rui Lopes
5ed61700ff Set the application icon, settings window title and systray tooltip (#523) 2022-10-26 14:34:30 +02:00
Misha Bragin
84117a9fb7 Update WireGuard trademark note 2022-10-23 11:47:42 +02:00
Misha Bragin
92b612eba4 Update demo video link 2022-10-22 16:55:49 +02:00
Misha Bragin
aeeaa21eed Update README.md (#524) 2022-10-22 16:19:16 +02:00
Misha Bragin
d228cd0cb1 Remove release note 2022-10-22 15:10:09 +02:00
Misha Bragin
b41f36fccd Add gRPC metrics (#522) 2022-10-22 15:06:54 +02:00
Misha Bragin
d2cde4a040 Add IdP metrics (#521) 2022-10-22 13:29:39 +02:00
Misha Bragin
84879a356b Extract app metrics to a separate struct (#520) 2022-10-22 11:50:21 +02:00
Misha Bragin
ed2214f9a9 Add HTTP request/response totals to metrics (#519) 2022-10-22 10:07:13 +02:00
braginini
4388dcc20b Listen metrics on all interfaces 2022-10-21 16:50:06 +02:00
Misha Bragin
4f1f0df7d2 Add Open-telemetry support (#517)
This PR brings open-telemetry metrics to the
Management service.
The Management service exposes new HTTP endpoint
/metrics on 8081 port by default.
The port can be changed by specifying
--metrics-port PORT flag when starting the service.
2022-10-21 16:24:13 +02:00
Misha Bragin
08ddf04c5f Fix IdP tests (#516) 2022-10-19 18:36:10 +02:00
Misha Bragin
b5ee2174a8 Do not set wt_pending_invite when unnecessary (#515)
wt_pending_invite property is set for every user on IdP.
Avoid setting it when unnecessary.
2022-10-19 17:51:41 +02:00
Misha Bragin
7218a3d563 Management single account mode (#511) 2022-10-19 17:43:28 +02:00
87 changed files with 4690 additions and 1411 deletions

View File

@@ -25,7 +25,6 @@ jobs:
needs: pre
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.3"
SIGN_PIPE_VER: "v0.0.4"
GORELEASER_VER: "v1.6.3"
jobs:

View File

@@ -1,6 +1,6 @@
<p align="center">
<strong>:hatching_chick: New release! NetBird Easy SSH</strong>.
<a href="https://github.com/netbirdio/netbird/releases/tag/v0.8.0">
<strong>:hatching_chick: New Release! User Invites.</strong>
<a href="https://github.com/netbirdio/netbird/releases">
Learn more
</a>
</p>
@@ -62,9 +62,8 @@ NetBird creates an overlay peer-to-peer network connecting machines automaticall
- \[ ] Network Activity Monitoring.
### Secure peer-to-peer VPN with SSO and MFA in minutes
<p float="left" align="middle">
<img src="docs/media/netbird-sso-mfa-demo.gif" width="800"/>
</p>
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
@@ -104,5 +103,5 @@ See a complete [architecture overview](https://netbird.io/docs/overview/architec
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal
[WireGuard](https://wireguard.com/) is a registered trademark of Jason A. Donenfeld.
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -62,18 +62,18 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, err := mgmt.NewStore(config.Datadir)
store, err := mgmt.NewFileStore(config.Datadir)
if err != nil {
t.Fatal(err)
}
peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -0,0 +1,56 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync"
)
type localResolver struct {
registeredMap registrationMap
records sync.Map
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received question: %#v\n", r.Question[0])
response := d.lookupRecord(r)
if response == nil {
log.Debugf("got empty response for question: %#v\n", r.Question[0])
return
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.Answer = append(replyMessage.Answer, response)
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
record, found := d.records.Load(r.Question[0].Name)
if !found {
return nil
}
return record.(dns.RR)
}
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
fullRecord, err := dns.NewRR(record.String())
if err != nil {
return err
}
d.records.Store(fullRecord.Header().Name, fullRecord)
return nil
}
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}

View File

@@ -0,0 +1,86 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"strings"
"testing"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@@ -0,0 +1,35 @@
package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
)
// MockServer is the mock instance of a dns server
type MockServer struct {
StartFunc func()
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
}
// Start mock implementation of Start from Server interface
func (m *MockServer) Start() {
if m.StartFunc != nil {
m.StartFunc()
}
}
// Stop mock implementation of Stop from Server interface
func (m *MockServer) Stop() {
if m.StopFunc != nil {
m.StopFunc()
}
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil {
return m.UpdateDNSServerFunc(serial, update)
}
return fmt.Errorf("method UpdateDNSServer is not implemented")
}

View File

@@ -0,0 +1,25 @@
package dns
import (
"github.com/miekg/dns"
"net"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@@ -0,0 +1,277 @@
package dns
import (
"context"
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync"
"time"
)
const (
port = 5053
defaultIP = "0.0.0.0"
)
// Server is a dns server interface
type Server interface {
Start()
Stop()
UpdateDNSServer(serial uint64, update nbdns.Config) error
}
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registrationMap
localResolver *localResolver
updateSerial uint64
listenerIsRunning bool
}
type registrationMap map[string]struct{}
type muxUpdate struct {
domain string
handler dns.Handler
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context) *DefaultServer {
mux := dns.NewServeMux()
dnsServer := &dns.Server{
Addr: fmt.Sprintf("%s:%d", defaultIP, port),
Net: "udp",
Handler: mux,
UDPSize: 65535,
}
ctx, stop := context.WithCancel(ctx)
return &DefaultServer{
ctx: ctx,
stop: stop,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registrationMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
}
}
// Start runs the listener in a go routine
func (s *DefaultServer) Start() {
log.Debugf("starting dns on %s:%d", defaultIP, port)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server returned an error: %v", err)
}
}()
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.stop()
err := s.stopListener()
if err != nil {
log.Error(err)
}
}
func (s *DefaultServer) stopListener() error {
if !s.listenerIsRunning {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
}
return nil
}
// UpdateDNSServer processes an update received from the management service
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
select {
case <-s.ctx.Done():
log.Infof("not updating DNS server as context is closed")
return s.ctx.Err()
default:
if serial < s.updateSerial {
return fmt.Errorf("not applying dns update, error: "+
"network update is %d behind the last applied update", s.updateSerial-serial)
}
s.mux.Lock()
defer s.mux.Unlock()
// is the service should be disabled, we stop the listener
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.Start()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
s.updateSerial = serial
return nil
}
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate
localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: customZone.Domain,
handler: s.localResolver,
})
for _, record := range customZone.Records {
localRecords[record.Name] = record
}
}
return muxUpdates, localRecords, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 {
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
}
handler := &upstreamResolver{
parentCTX: s.ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: defaultUpstreamTimeout,
}
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
}
if len(handler.upstreamServers) == 0 {
log.Errorf("received a nameserver group with an invalid nameserver list")
continue
}
if nsGroup.Primary {
muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
handler: handler,
})
continue
}
if len(nsGroup.Domains) == 0 {
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
})
}
}
return muxUpdates, nil
}
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registrationMap)
for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler)
muxUpdateMap[update.domain] = struct{}{}
}
for key := range s.dnsMuxMap {
_, found := muxUpdateMap[key]
if !found {
s.deregisterMux(key)
}
}
s.dnsMuxMap = muxUpdateMap
}
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
s.localResolver.deleteRecord(key)
}
}
updatedMap := make(registrationMap)
for key, record := range update {
err := s.localResolver.registerRecord(record)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
}
updatedMap[key] = struct{}{}
}
s.localResolver.registeredMap = updatedMap
}
func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}

View File

@@ -0,0 +1,285 @@
package dns
import (
"context"
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
)
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
},
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap registrationMap
initLocalMap registrationMap
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registrationMap
expectedLocalMap registrationMap
}{
{
name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}},
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}},
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
shouldFail: true,
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registrationMap),
expectedLocalMap: make(registrationMap),
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx := context.Background()
dnsServer := NewDefaultServer(ctx)
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial
dnsServer.listenerIsRunning = true
err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
}
}
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
}
for key := range testCase.expectedLocalMap {
_, found := dnsServer.localResolver.registeredMap[key]
if !found {
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
}
}
})
}
}
func TestDNSServerStartStop(t *testing.T) {
ctx := context.Background()
dnsServer := NewDefaultServer(ctx)
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
// todo review why this test is not working only on github actions workflows
t.Skip("skipping test in Windows CI workflows.")
}
dnsServer.Start()
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("127.0.0.1:%d", port)
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
// retry test before exit, for slower systems
return d.DialContext(ctx, network, addr)
}
return conn, nil
},
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed to connect to the server, error: %v", err)
}
t.Log(ips)
if ips[0] != zoneRecords[0].RData {
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
}
dnsServer.Stop()
ctx, cancel := context.WithTimeout(ctx, time.Second*1)
defer cancel()
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
if err == nil {
t.Fatalf("we should encounter an error when querying a stopped server")
}
}

View File

@@ -0,0 +1,67 @@
package dns
import (
"context"
"errors"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"net"
"time"
)
const defaultUpstreamTimeout = 15 * time.Second
type upstreamResolver struct {
parentCTX context.Context
upstreamClient *dns.Client
upstreamServers []string
upstreamTimeout time.Duration
}
// ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received an upstream question: %#v", r.Question[0])
select {
case <-u.parentCTX.Done():
return
default:
}
for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
continue
}
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
return
}
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
if err != nil {
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
}
return
}
log.Errorf("all queries to the upstream nameservers failed with timeout")
}
// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
func isTimeout(err error) bool {
var neterr net.Error
if errors.As(err, &neterr) {
return neterr != nil && neterr.Timeout()
}
return false
}

View File

@@ -0,0 +1,110 @@
package dns
import (
"context"
"github.com/miekg/dns"
"strings"
"testing"
"time"
)
func TestUpstreamResolver_ServeDNS(t *testing.T) {
testCases := []struct {
name string
inputMSG *dns.Msg
responseShouldBeNil bool
InputServers []string
timeout time.Duration
cancelCTX bool
expectedAnswer string
}{
{
name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: defaultUpstreamTimeout,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Resolve If First Upstream Times Out",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
timeout: 2 * time.Second,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Not Resolve If Can't Connect To Both Servers",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
timeout: 200 * time.Millisecond,
responseShouldBeNil: true,
},
{
name: "Should Not Resolve If Parent Context Is Canceled",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true,
timeout: defaultUpstreamTimeout,
responseShouldBeNil: true,
},
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
//},
}
// should resolve if first upstream times out
// should not write when both fails
// should not resolve if parent context is canceled
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver := &upstreamResolver{
parentCTX: ctx,
upstreamClient: &dns.Client{},
upstreamServers: testCase.InputServers,
upstreamTimeout: testCase.timeout,
}
if testCase.cancelCTX {
cancel()
} else {
defer cancel()
}
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
foundAnswer := false
for _, answer := range responseMSG.Answer {
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
}
if !foundAnswer {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
}
})
}
}

View File

@@ -3,12 +3,15 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"math/rand"
"net"
"net/netip"
"reflect"
"runtime"
"strings"
@@ -103,6 +106,8 @@ type Engine struct {
statusRecorder *nbstatus.Status
routeManager routemanager.Manager
dnsServer dns.Server
}
// Peer is an instance of the Connection Peer
@@ -130,6 +135,7 @@ func NewEngine(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
dnsServer: dns.NewDefaultServer(ctx),
}
}
@@ -190,6 +196,10 @@ func (e *Engine) Stop() error {
e.routeManager.Stop()
}
if e.dnsServer != nil {
e.dnsServer.Stop()
}
log.Infof("stopped Netbird Engine")
return nil
@@ -638,6 +648,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update routes, err: %v", err)
}
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
if err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial
return nil
}
@@ -660,6 +679,48 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
}
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
Name: record.GetName(),
Type: int(record.GetType()),
Class: record.GetClass(),
TTL: int(record.GetTTL()),
RData: record.GetRData(),
}
dnsZone.Records = append(dnsZone.Records, dnsRecord)
}
dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone)
}
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
dnsNSGroup := &nbdns.NameServerGroup{
Primary: nsGroup.GetPrimary(),
Domains: nsGroup.GetDomains(),
}
for _, ns := range nsGroup.GetNameServers() {
dnsNS := nbdns.NameServer{
IP: netip.MustParseAddr(ns.GetIP()),
NSType: nbdns.NameServerType(ns.GetNSType()),
Port: int(ns.GetPort()),
}
dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS)
}
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
}
return dnsUpdate
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate {

View File

@@ -3,9 +3,11 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
@@ -440,7 +442,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial uint64
}{
{
name: "Routes Update Should Be Passed To Manager",
name: "Routes Config Should Be Passed To Manager",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
@@ -486,7 +488,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial: 1,
},
{
name: "Empty Routes Update Should Be Passed",
name: "Empty Routes Config Should Be Passed",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
@@ -566,6 +568,183 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}
}
func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
testCases := []struct {
name string
inputErr error
networkMap *mgmtProto.NetworkMap
expectedZonesLen int
expectedZones []nbdns.CustomZone
expectedNSGroupsLen int
expectedNSGroups []*nbdns.NameServerGroup
expectedSerial uint64
}{
{
name: "DNS Config Should Be Passed To DNS Server",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: &mgmtProto.DNSConfig{
ServiceEnable: true,
CustomZones: []*mgmtProto.CustomZone{
{
Domain: "netbird.cloud.",
Records: []*mgmtProto.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
NameServerGroups: []*mgmtProto.NameServerGroup{
{
Primary: true,
NameServers: []*mgmtProto.NameServer{
{
IP: "8.8.8.8",
NSType: 1,
Port: 53,
},
},
},
},
},
},
expectedZonesLen: 1,
expectedZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: 1,
Port: 53,
},
},
},
},
expectedSerial: 1,
},
{
name: "Empty DNS Config Should Be OK",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
{
name: "Error Shouldn't Break Engine",
inputErr: fmt.Errorf("mocking error"),
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
// test setup
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, nbstatus.NewRecorder())
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}
engine.routeManager = mockRouteManager
input := struct {
inputSerial uint64
inputNSGroups []*nbdns.NameServerGroup
inputZones []nbdns.CustomZone
}{}
mockDNSServer := &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error {
input.inputSerial = serial
input.inputZones = update.CustomZones
input.inputNSGroups = update.NameServerGroups
return testCase.inputErr
},
}
engine.dnsServer = mockDNSServer
defer func() {
exitErr := engine.Stop()
if exitErr != nil {
return
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
assert.Equal(t, testCase.expectedZones, input.inputZones, "custom zones should match")
assert.Len(t, input.inputNSGroups, testCase.expectedNSGroupsLen, "ns groups len should match")
assert.Equal(t, testCase.expectedNSGroups, input.inputNSGroups, "ns groups should match")
})
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
@@ -756,17 +935,17 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
return nil, err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStore(config.Datadir)
store, err := server.NewFileStore(config.Datadir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil {
return nil, err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil {
return nil, err
}

View File

@@ -21,7 +21,7 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
}
if prefixGateway != nil && !prefixGateway.Equal(gateway) {
log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway)
log.Warnf("skipping adding a new route for network %s because it already exists and is pointing to the non default gateway: %s", prefix, prefixGateway)
return nil
}
return addToRouteTable(prefix, addr)
@@ -45,11 +45,14 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
if err != nil {
return nil, err
}
_, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice())
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
if err != nil {
log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound
}
if gateway == nil {
return preferredSrc, nil
}
return localGatewayAddress, nil
return gateway, nil
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/netbirdio/netbird/iface"
"github.com/stretchr/testify/require"
"net"
"net/netip"
"testing"
)
@@ -66,3 +67,45 @@ func TestAddRemoveRoutes(t *testing.T) {
})
}
}
func TestGetExistingRIBRouteGateway(t *testing.T) {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if gateway == nil {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
continue
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
localIP, err := getExistingRIBRouteGateway(testingPrefix)
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if localIP == nil {
t.Fatal("should return a gateway for local network")
}
if localIP.String() == gateway.String() {
t.Fatal("local ip should not match with gateway IP")
}
if localIP.String() != testingIP {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
}
}

View File

@@ -8,7 +8,6 @@ import (
"context"
"flag"
"fmt"
"github.com/netbirdio/netbird/client/system"
"os"
"os/exec"
"path"
@@ -18,6 +17,8 @@ import (
"syscall"
"time"
"github.com/netbirdio/netbird/client/system"
"github.com/cenkalti/backoff/v4"
_ "embed"
@@ -61,6 +62,8 @@ func main() {
flag.Parse()
a := app.New()
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
client := newServiceClient(daemonAddr, a, showSettings)
if showSettings {
a.Run()
@@ -113,7 +116,7 @@ type serviceClient struct {
iLogFile *widget.Entry
iPreSharedKey *widget.Entry
// observable settings over correspondign iMngURL and iPreSharedKey values.
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
preSharedKey string
adminURL string
@@ -121,7 +124,7 @@ type serviceClient struct {
// newServiceClient instance constructor
//
// This constructor olso build UI elements for settings window.
// This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient {
s := &serviceClient{
ctx: context.Background(),
@@ -149,7 +152,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
func (s *serviceClient) showUIElements() {
// add settings window UI elements.
s.wSettings = s.app.NewWindow("Settings")
s.wSettings = s.app.NewWindow("NetBird Settings")
s.iMngURL = widget.NewEntry()
s.iAdminURL = widget.NewEntry()
s.iConfigFile = widget.NewEntry()
@@ -326,11 +329,13 @@ func (s *serviceClient) updateStatus() error {
if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() {
systray.SetIcon(s.icConnected)
systray.SetTooltip("NetBird (Connected)")
s.mStatus.SetTitle("Connected")
s.mUp.Disable()
s.mDown.Enable()
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird (Disconnected)")
s.mStatus.SetTitle("Disconnected")
s.mDown.Disable()
s.mUp.Enable()
@@ -355,6 +360,7 @@ func (s *serviceClient) updateStatus() error {
func (s *serviceClient) onTrayReady() {
systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird")
// setup systray menu items
s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected")

View File

@@ -2,5 +2,83 @@
// to parse and normalize dns records and configuration
package dns
// DefaultDNSPort well-known port number
const DefaultDNSPort = 53
import (
"fmt"
"github.com/miekg/dns"
"golang.org/x/net/idna"
"regexp"
"strings"
)
const (
// DefaultDNSPort well-known port number
DefaultDNSPort = 53
// RootZone is a string representation of the root zone
RootZone = "."
// DefaultClass is the class supported by the system
DefaultClass = "IN"
)
const invalidHostLabel = "[^a-zA-Z0-9-]+"
// Config represents a dns configuration that is exchanged between management and peers
type Config struct {
// ServiceEnable indicates if the service should be enabled
ServiceEnable bool
// NameServerGroups contains a list of nameserver group
NameServerGroups []*NameServerGroup
// CustomZones contains a list of custom zone
CustomZones []CustomZone
}
// CustomZone represents a custom zone to be resolved by the dns server
type CustomZone struct {
// Domain is the zone's domain
Domain string
// Records custom zone records
Records []SimpleRecord
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
type SimpleRecord struct {
// Name domain name
Name string
// Type of record, 1 for A, 5 for CNAME, 28 for AAAA. see https://pkg.go.dev/github.com/miekg/dns@v1.1.41#pkg-constants
Type int
// Class dns class, currently use the DefaultClass for all records
Class string
// TTL time-to-live for the record
TTL int
// RData is the actual value resolved in a dns query
RData string
}
// String returns a string of the simple record formatted as:
// <Name> <TTL> <Class> <Type> <RDATA>
func (s SimpleRecord) String() string {
fqdn := dns.Fqdn(s.Name)
return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData)
}
// GetParsedDomainLabel returns a domain label with max 59 characters,
// parsed for old Hosts.txt requirements, and converted to ASCII and lowercase
func GetParsedDomainLabel(name string) (string, error) {
labels := dns.SplitDomainName(name)
if len(labels) == 0 {
return "", fmt.Errorf("got empty label list for name \"%s\"", name)
}
rawLabel := labels[0]
ascii, err := idna.Punycode.ToASCII(rawLabel)
if err != nil {
return "", fmt.Errorf("unable to convert host lavel to ASCII, error: %v", err)
}
invalidHostMatcher := regexp.MustCompile(invalidHostLabel)
validHost := strings.ToLower(invalidHostMatcher.ReplaceAllString(ascii, "-"))
if len(validHost) > 58 {
validHost = validHost[:59]
}
return validHost, nil
}

View File

@@ -9,8 +9,6 @@ import (
)
const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerType invalid nameserver type
InvalidNameServerType NameServerType = iota
// UDPNameServerType udp nameserver type
@@ -18,6 +16,8 @@ const (
)
const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerTypeString invalid nameserver type as string
InvalidNameServerTypeString = "invalid"
// UDPNameServerTypeString udp nameserver type as string
@@ -59,6 +59,10 @@ type NameServerGroup struct {
NameServers []NameServer
// Groups list of peer group IDs to distribute the nameservers information
Groups []string
// Primary indicates that the nameserver group is the primary resolver for any dns query
Primary bool
// Domains indicate the dns query domains to use with this nameserver group
Domains []string
// Enabled group status
Enabled bool
}
@@ -128,6 +132,8 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
NameServers: g.NameServers,
Groups: g.Groups,
Enabled: g.Enabled,
Primary: g.Primary,
Domains: g.Domains,
}
}
@@ -136,8 +142,10 @@ func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
return other.ID == g.ID &&
other.Name == g.Name &&
other.Description == g.Description &&
other.Primary == g.Primary &&
compareNameServerList(g.NameServers, other.NameServers) &&
compareGroupsList(g.Groups, other.Groups)
compareGroupsList(g.Groups, other.Groups) &&
compareGroupsList(g.Domains, other.Domains)
}
func compareNameServerList(list, other []NameServer) bool {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 887 KiB

23
go.mod
View File

@@ -18,12 +18,12 @@ require (
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.43.0
google.golang.org/protobuf v1.28.0
google.golang.org/protobuf v1.28.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -38,10 +38,15 @@ require (
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5
github.com/miekg/dns v1.1.41
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.13.0
github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.0
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0
go.opentelemetry.io/otel/sdk/metric v0.33.0
golang.org/x/net v0.0.0-20220630215102-69896b714898
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
)
@@ -65,11 +70,13 @@ require (
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/google/go-cmp v0.5.7 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
@@ -89,20 +96,22 @@ require (
github.com/pion/turn/v2 v2.0.8 // indirect
github.com/pion/udp v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.12.2 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.33.0 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
github.com/yuin/goldmark v1.4.1 // indirect
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect
golang.org/x/tools v0.1.10 // indirect
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect

44
go.sum
View File

@@ -202,6 +202,11 @@ github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KE
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU=
github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -278,8 +283,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -450,6 +455,7 @@ github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMV
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
@@ -542,8 +548,8 @@ github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3O
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_golang v1.12.2 h1:51L9cDoUHVrXx4zWYlcLQIZ+d+VXHgqnYKkIuq4g/34=
github.com/prometheus/client_golang v1.12.2/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_golang v1.13.0 h1:b71QUfeo5M8gq2+evJdTPfZhYMAU0uKPkyPJ7TPsloU=
github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -554,15 +560,16 @@ github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8b
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls=
github.com/prometheus/common v0.33.0 h1:rHgav/0a6+uYgGdNt3jwz8FNSesO/Hsang3O0T9A5SE=
github.com/prometheus/common v0.33.0/go.mod h1:gB3sOl7P0TvJabZpLY5uQMpUqRCPPCyRLCZYc7JZTNE=
github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE=
github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
@@ -648,6 +655,18 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0/go.mod h1:ZSmYfKdYWEdSDBB4njLBIwTf4AU2JNsH3n2quVQDebI=
go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E=
go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI=
go.opentelemetry.io/otel/sdk v1.11.1 h1:F7KmQgoHljhUuJyA+9BiU+EkJfyX5nVVF4wyzWZpKxs=
go.opentelemetry.io/otel/sdk v1.11.1/go.mod h1:/l3FE4SupHJ12TduVjUkZtlfFqDCQJlOlithYrdktys=
go.opentelemetry.io/otel/sdk/metric v0.33.0 h1:oTqyWfksgKoJmbrs2q7O7ahkJzt+Ipekihf8vhpa9qo=
go.opentelemetry.io/otel/sdk/metric v0.33.0/go.mod h1:xdypMeA21JBOvjjzDUtD0kzIcHO/SPez+a8HOzJPGp0=
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
@@ -809,8 +828,9 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f h1:Ax0t5p6N38Ga0dThY21weqDEyz2oklo4IvDkpigvkD8=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -915,8 +935,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
@@ -1162,8 +1182,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -89,7 +89,6 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
return addrs, nil
}
//
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
@@ -369,8 +368,8 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil {
t.Fatal(err)
}
timeout := 10 * time.Second
// todo: investigate why in some tests execution we need 30s
timeout := 30 * time.Second
timeoutChannel := time.After(timeout)
for {
select {

View File

@@ -214,8 +214,9 @@ func isBuiltinModule(name string) (bool, error) {
}
// /proc/modules
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
// macvlan 28672 1 macvtap, Live 0x0000000000000000
//
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
// macvlan 28672 1 macvtap, Live 0x0000000000000000
func moduleStatus(name string) (status, error) {
state := unknown
f, err := os.Open("/proc/modules")

View File

@@ -10,6 +10,8 @@ NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem"
# Management Certficate key file path.
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem"
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
# Turn credentials

View File

@@ -48,7 +48,7 @@ services:
# # port and command for Let's Encrypt validation without dashboard container
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS"]
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN"]
# Coturn
coturn:
image: coturn/coturn

View File

@@ -49,18 +49,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
store, err := mgmt.NewStore(config.Datadir)
store, err := mgmt.NewFileStore(config.Datadir)
if err != nil {
t.Fatal(err)
}
peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -0,0 +1,18 @@
package cmd
const (
defaultMgmtDataDir = "/var/lib/netbird/"
defaultMgmtConfigDir = "/etc/netbird"
defaultLogDir = "/var/log/netbird"
oldDefaultMgmtDataDir = "/var/lib/wiretrustee/"
oldDefaultMgmtConfigDir = "/etc/wiretrustee"
oldDefaultLogDir = "/var/log/wiretrustee"
defaultMgmtConfig = defaultMgmtConfigDir + "/management.json"
defaultLogFile = defaultLogDir + "/management.log"
oldDefaultMgmtConfig = oldDefaultMgmtConfigDir + "/management.json"
oldDefaultLogFile = oldDefaultLogDir + "/management.log"
defaultSingleAccModeDomain = "netbird.selfhosted"
)

View File

@@ -8,8 +8,10 @@ import (
"flag"
"fmt"
"github.com/google/uuid"
"github.com/miekg/dns"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
@@ -41,11 +43,13 @@ import (
const ManagementLegacyPort = 33073
var (
mgmtPort int
mgmtLetsencryptDomain string
certFile string
certKey string
config *server.Config
mgmtPort int
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
config *server.Config
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@@ -86,6 +90,11 @@ var (
}
}
_, valid := dns.IsDomainName(dnsDomain)
if !valid || len(dnsDomain) > 192 {
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Lenght: %d", valid, len(dnsDomain))
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
@@ -107,21 +116,33 @@ var (
}
}
store, err := server.NewStore(config.Datadir)
store, err := server.NewFileStore(config.Datadir)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
appMetrics, err := telemetry.NewDefaultAppMetrics(cmd.Context())
if err != nil {
return err
}
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(*config.IdpManagerConfig)
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager)
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, dnsDomain)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}
@@ -151,14 +172,14 @@ var (
tlsEnabled = true
}
httpAPIHandler, err := httpapi.APIHandler(accountManager,
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation)
httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
@@ -170,8 +191,6 @@ var (
return err
}
fmt.Println("metrics ", disableMetrics)
if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -225,11 +244,13 @@ var (
SetupCloseHandler()
<-stopCh
_ = appMetrics.Close()
_ = listener.Close()
if certManager != nil {
_ = certManager.Listener().Close()
}
gRPCAPIHandler.Stop()
_ = store.Close()
log.Infof("stopped Management Service")
return nil
@@ -427,7 +448,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
return nil, err
}
// Create the credentials and return it
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,

View File

@@ -13,21 +13,13 @@ const (
)
var (
defaultMgmtConfigDir string
defaultMgmtDataDir string
defaultMgmtConfig string
defaultLogDir string
defaultLogFile string
oldDefaultMgmtConfigDir string
oldDefaultMgmtDataDir string
oldDefaultMgmtConfig string
oldDefaultLogDir string
oldDefaultLogFile string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
dnsDomain string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@@ -46,28 +38,17 @@ func Execute() error {
func init() {
stopCh = make(chan int)
defaultMgmtDataDir = "/var/lib/netbird/"
defaultMgmtConfigDir = "/etc/netbird"
defaultLogDir = "/var/log/netbird"
oldDefaultMgmtDataDir = "/var/lib/wiretrustee/"
oldDefaultMgmtConfigDir = "/etc/wiretrustee"
oldDefaultLogDir = "/var/log/wiretrustee"
defaultMgmtConfig = defaultMgmtConfigDir + "/management.json"
defaultLogFile = defaultLogDir + "/management.log"
oldDefaultMgmtConfig = oldDefaultMgmtConfigDir + "/management.json"
oldDefaultLogFile = oldDefaultLogDir + "/management.log"
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -1,15 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.12.4
// protoc v3.21.9
// source: management.proto
package proto
import (
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
)
@@ -611,7 +611,7 @@ type ServerKeyResponse struct {
// Server's Wireguard public key
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
// Key expiration timestamp after which the key should be fetched again by the client
ExpiresAt *timestamp.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"`
ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"`
// Version of the Wiretrustee Management Service protocol
Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"`
}
@@ -655,7 +655,7 @@ func (x *ServerKeyResponse) GetKey() string {
return ""
}
func (x *ServerKeyResponse) GetExpiresAt() *timestamp.Timestamp {
func (x *ServerKeyResponse) GetExpiresAt() *timestamppb.Timestamp {
if x != nil {
return x.ExpiresAt
}
@@ -982,6 +982,8 @@ type NetworkMap struct {
RemotePeersIsEmpty bool `protobuf:"varint,4,opt,name=remotePeersIsEmpty,proto3" json:"remotePeersIsEmpty,omitempty"`
// List of routes to be applied
Routes []*Route `protobuf:"bytes,5,rep,name=Routes,proto3" json:"Routes,omitempty"`
// DNS config to be applied
DNSConfig *DNSConfig `protobuf:"bytes,6,opt,name=DNSConfig,proto3" json:"DNSConfig,omitempty"`
}
func (x *NetworkMap) Reset() {
@@ -1051,6 +1053,13 @@ func (x *NetworkMap) GetRoutes() []*Route {
return nil
}
func (x *NetworkMap) GetDNSConfig() *DNSConfig {
if x != nil {
return x.DNSConfig
}
return nil
}
// RemotePeerConfig represents a configuration of a remote peer.
// The properties are used to configure Wireguard Peers sections
type RemotePeerConfig struct {
@@ -1467,6 +1476,334 @@ func (x *Route) GetNetID() string {
return ""
}
// DNSConfig represents a dns.Update
type DNSConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"`
NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"`
CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"`
}
func (x *DNSConfig) Reset() {
*x = DNSConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[20]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DNSConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DNSConfig) ProtoMessage() {}
func (x *DNSConfig) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[20]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead.
func (*DNSConfig) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{20}
}
func (x *DNSConfig) GetServiceEnable() bool {
if x != nil {
return x.ServiceEnable
}
return false
}
func (x *DNSConfig) GetNameServerGroups() []*NameServerGroup {
if x != nil {
return x.NameServerGroups
}
return nil
}
func (x *DNSConfig) GetCustomZones() []*CustomZone {
if x != nil {
return x.CustomZones
}
return nil
}
// CustomZone represents a dns.CustomZone
type CustomZone struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"`
Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"`
}
func (x *CustomZone) Reset() {
*x = CustomZone{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[21]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CustomZone) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CustomZone) ProtoMessage() {}
func (x *CustomZone) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[21]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CustomZone.ProtoReflect.Descriptor instead.
func (*CustomZone) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{21}
}
func (x *CustomZone) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *CustomZone) GetRecords() []*SimpleRecord {
if x != nil {
return x.Records
}
return nil
}
// SimpleRecord represents a dns.SimpleRecord
type SimpleRecord struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
Type int64 `protobuf:"varint,2,opt,name=Type,proto3" json:"Type,omitempty"`
Class string `protobuf:"bytes,3,opt,name=Class,proto3" json:"Class,omitempty"`
TTL int64 `protobuf:"varint,4,opt,name=TTL,proto3" json:"TTL,omitempty"`
RData string `protobuf:"bytes,5,opt,name=RData,proto3" json:"RData,omitempty"`
}
func (x *SimpleRecord) Reset() {
*x = SimpleRecord{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[22]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SimpleRecord) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SimpleRecord) ProtoMessage() {}
func (x *SimpleRecord) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[22]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead.
func (*SimpleRecord) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{22}
}
func (x *SimpleRecord) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *SimpleRecord) GetType() int64 {
if x != nil {
return x.Type
}
return 0
}
func (x *SimpleRecord) GetClass() string {
if x != nil {
return x.Class
}
return ""
}
func (x *SimpleRecord) GetTTL() int64 {
if x != nil {
return x.TTL
}
return 0
}
func (x *SimpleRecord) GetRData() string {
if x != nil {
return x.RData
}
return ""
}
// NameServerGroup represents a dns.NameServerGroup
type NameServerGroup struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
NameServers []*NameServer `protobuf:"bytes,1,rep,name=NameServers,proto3" json:"NameServers,omitempty"`
Primary bool `protobuf:"varint,2,opt,name=Primary,proto3" json:"Primary,omitempty"`
Domains []string `protobuf:"bytes,3,rep,name=Domains,proto3" json:"Domains,omitempty"`
}
func (x *NameServerGroup) Reset() {
*x = NameServerGroup{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[23]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *NameServerGroup) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*NameServerGroup) ProtoMessage() {}
func (x *NameServerGroup) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[23]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead.
func (*NameServerGroup) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{23}
}
func (x *NameServerGroup) GetNameServers() []*NameServer {
if x != nil {
return x.NameServers
}
return nil
}
func (x *NameServerGroup) GetPrimary() bool {
if x != nil {
return x.Primary
}
return false
}
func (x *NameServerGroup) GetDomains() []string {
if x != nil {
return x.Domains
}
return nil
}
// NameServer represents a dns.NameServer
type NameServer struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
NSType int64 `protobuf:"varint,2,opt,name=NSType,proto3" json:"NSType,omitempty"`
Port int64 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"`
}
func (x *NameServer) Reset() {
*x = NameServer{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *NameServer) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*NameServer) ProtoMessage() {}
func (x *NameServer) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[24]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use NameServer.ProtoReflect.Descriptor instead.
func (*NameServer) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{24}
}
func (x *NameServer) GetIP() string {
if x != nil {
return x.IP
}
return ""
}
func (x *NameServer) GetNSType() int64 {
if x != nil {
return x.NSType
}
return 0
}
func (x *NameServer) GetPort() int64 {
if x != nil {
return x.Port
}
return 0
}
var File_management_proto protoreflect.FileDescriptor
var file_management_proto_rawDesc = []byte{
@@ -1583,7 +1920,7 @@ var file_management_proto_rawDesc = []byte{
0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53,
0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x22, 0xf7, 0x01, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
0x66, 0x69, 0x67, 0x22, 0xac, 0x02, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01,
0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65,
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16,
@@ -1598,85 +1935,125 @@ var file_management_proto_rawDesc = []byte{
0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70,
0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03,
0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x83, 0x01,
0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e,
0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03,
0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33,
0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53,
0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01,
0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64,
0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20,
0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a,
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f,
0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08,
0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69,
0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46,
0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72,
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64,
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72,
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44,
0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65,
0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53,
0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a,
0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52,
0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76,
0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18,
0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74,
0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b,
0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09,
0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22,
0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74,
0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77,
0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79,
0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20,
0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74,
0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69,
0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18,
0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64,
0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a,
0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d,
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c,
0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d,
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a,
0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e,
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79,
0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d,
0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69,
0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46,
0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a,
0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b,
0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e,
0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x22, 0x83, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65,
0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62,
0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62,
0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70,
0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64,
0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73,
0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62,
0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e,
0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b,
0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62,
0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74,
0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65,
0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f,
0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65,
0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50,
0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20,
0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22,
0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48,
0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c,
0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c,
0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c,
0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61,
0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04,
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e,
0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69,
0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24,
0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18,
0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x22, 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e,
0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18,
0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77,
0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e,
0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65,
0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16,
0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06,
0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65,
0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71,
0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18,
0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a,
0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65,
0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72,
0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76,
0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73,
0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74,
0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e,
0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28,
0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63,
0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65,
0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a,
0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a,
0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d,
0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52,
0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03,
0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54,
0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a,
0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44,
0x61, 0x74, 0x61, 0x22, 0x7f, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73,
0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28,
0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d,
0x61, 0x69, 0x6e, 0x73, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76,
0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02,
0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01,
0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f,
0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, 0xf7,
0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e,
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79,
0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53,
0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61,
0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a,
0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69,
0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -1692,7 +2069,7 @@ func file_management_proto_rawDescGZIP() []byte {
}
var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 20)
var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 25)
var file_management_proto_goTypes = []interface{}{
(HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol
(DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider
@@ -1716,7 +2093,12 @@ var file_management_proto_goTypes = []interface{}{
(*DeviceAuthorizationFlow)(nil), // 19: management.DeviceAuthorizationFlow
(*ProviderConfig)(nil), // 20: management.ProviderConfig
(*Route)(nil), // 21: management.Route
(*timestamp.Timestamp)(nil), // 22: google.protobuf.Timestamp
(*DNSConfig)(nil), // 22: management.DNSConfig
(*CustomZone)(nil), // 23: management.CustomZone
(*SimpleRecord)(nil), // 24: management.SimpleRecord
(*NameServerGroup)(nil), // 25: management.NameServerGroup
(*NameServer)(nil), // 26: management.NameServer
(*timestamppb.Timestamp)(nil), // 27: google.protobuf.Timestamp
}
var file_management_proto_depIdxs = []int32{
11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
@@ -1727,7 +2109,7 @@ var file_management_proto_depIdxs = []int32{
6, // 5: management.LoginRequest.peerKeys:type_name -> management.PeerKeys
11, // 6: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
14, // 7: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
22, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
27, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
12, // 9: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig
13, // 10: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
12, // 11: management.WiretrusteeConfig.signal:type_name -> management.HostConfig
@@ -1737,24 +2119,29 @@ var file_management_proto_depIdxs = []int32{
14, // 15: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
16, // 16: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
21, // 17: management.NetworkMap.Routes:type_name -> management.Route
17, // 18: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
1, // 19: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
20, // 20: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
2, // 21: management.ManagementService.Login:input_type -> management.EncryptedMessage
2, // 22: management.ManagementService.Sync:input_type -> management.EncryptedMessage
10, // 23: management.ManagementService.GetServerKey:input_type -> management.Empty
10, // 24: management.ManagementService.isHealthy:input_type -> management.Empty
2, // 25: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
2, // 26: management.ManagementService.Login:output_type -> management.EncryptedMessage
2, // 27: management.ManagementService.Sync:output_type -> management.EncryptedMessage
9, // 28: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
10, // 29: management.ManagementService.isHealthy:output_type -> management.Empty
2, // 30: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
26, // [26:31] is the sub-list for method output_type
21, // [21:26] is the sub-list for method input_type
21, // [21:21] is the sub-list for extension type_name
21, // [21:21] is the sub-list for extension extendee
0, // [0:21] is the sub-list for field type_name
22, // 18: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
17, // 19: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
1, // 20: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
20, // 21: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
25, // 22: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
23, // 23: management.DNSConfig.CustomZones:type_name -> management.CustomZone
24, // 24: management.CustomZone.Records:type_name -> management.SimpleRecord
26, // 25: management.NameServerGroup.NameServers:type_name -> management.NameServer
2, // 26: management.ManagementService.Login:input_type -> management.EncryptedMessage
2, // 27: management.ManagementService.Sync:input_type -> management.EncryptedMessage
10, // 28: management.ManagementService.GetServerKey:input_type -> management.Empty
10, // 29: management.ManagementService.isHealthy:input_type -> management.Empty
2, // 30: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
2, // 31: management.ManagementService.Login:output_type -> management.EncryptedMessage
2, // 32: management.ManagementService.Sync:output_type -> management.EncryptedMessage
9, // 33: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
10, // 34: management.ManagementService.isHealthy:output_type -> management.Empty
2, // 35: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
31, // [31:36] is the sub-list for method output_type
26, // [26:31] is the sub-list for method input_type
26, // [26:26] is the sub-list for extension type_name
26, // [26:26] is the sub-list for extension extendee
0, // [0:26] is the sub-list for field type_name
}
func init() { file_management_proto_init() }
@@ -2003,6 +2390,66 @@ func file_management_proto_init() {
return nil
}
}
file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DNSConfig); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CustomZone); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SimpleRecord); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NameServerGroup); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NameServer); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@@ -2010,7 +2457,7 @@ func file_management_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_management_proto_rawDesc,
NumEnums: 2,
NumMessages: 20,
NumMessages: 25,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -178,6 +178,9 @@ message NetworkMap {
// List of routes to be applied
repeated Route Routes = 5;
// DNS config to be applied
DNSConfig DNSConfig = 6;
}
// RemotePeerConfig represents a configuration of a remote peer.
@@ -246,4 +249,40 @@ message Route {
int64 Metric = 5;
bool Masquerade = 6;
string NetID = 7;
}
// DNSConfig represents a dns.Update
message DNSConfig {
bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3;
}
// CustomZone represents a dns.CustomZone
message CustomZone {
string Domain = 1;
repeated SimpleRecord Records = 2;
}
// SimpleRecord represents a dns.SimpleRecord
message SimpleRecord {
string Name = 1;
int64 Type = 2;
string Class = 3;
int64 TTL = 4;
string RData = 5;
}
// NameServerGroup represents a dns.NameServerGroup
message NameServerGroup {
repeated NameServer NameServers = 1;
bool Primary = 2;
repeated string Domains = 3;
}
// NameServer represents a dns.NameServer
message NameServer {
string IP = 1;
int64 NSType = 2;
int64 Port = 3;
}

View File

@@ -15,6 +15,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"math/rand"
"net/netip"
"reflect"
"regexp"
"strings"
@@ -37,7 +38,6 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
CreateSetupKey(
accountId string,
keyName string,
@@ -47,17 +47,16 @@ type AccountManager interface {
) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
CreateUser(accountID string, key *UserInfo) (*UserInfo, error)
ListSetupKeys(accountID string) ([]*SetupKey, error)
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error)
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
UpdatePeer(accountID string, peer *Peer) (*Peer, error)
@@ -66,7 +65,7 @@ type AccountManager interface {
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
UpdatePeerSSHKey(peerKey string, sshKey string) error
GetUsersFromAccount(accountId string) ([]*UserInfo, error)
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountId string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
@@ -75,19 +74,19 @@ type AccountManager interface {
GroupAddPeer(accountId, groupID, peerKey string) error
GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetRule(accountId, ruleID string) (*Rule, error)
GetRule(accountID, ruleID, userID string) (*Rule, error)
SaveRule(accountID string, rule *Rule) error
UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error)
DeleteRule(accountId, ruleID string) error
ListRules(accountId string) ([]*Rule, error)
GetRoute(accountID, routeID string) (*route.Route, error)
ListRules(accountID, userID string) ([]*Rule, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
SaveRoute(accountID string, route *route.Route) error
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error
ListRoutes(accountID string) ([]*route.Route, error)
ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID string) error
@@ -96,8 +95,6 @@ type AccountManager interface {
type DefaultAccountManager struct {
Store Store
// mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
@@ -106,6 +103,15 @@ type DefaultAccountManager struct {
idpManager idp.Manager
cacheManager cache.CacheInterface[[]*idp.UserData]
ctx context.Context
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
}
// Account represents a unique account of the system
@@ -135,6 +141,153 @@ type UserInfo struct {
Status string `json:"-"`
}
// GetPeersRoutes returns all active routes of provided peers
func (a *Account) GetPeersRoutes(givenPeers []*Peer) []*route.Route {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
routes := make([]*route.Route, 0)
for _, peer := range givenPeers {
peerRoutes := a.GetPeerRoutes(peer.Key)
activeRoutes := make([]*route.Route, 0)
for _, pr := range peerRoutes {
if pr.Enabled {
activeRoutes = append(activeRoutes, pr)
}
}
if len(activeRoutes) > 0 {
routes = append(routes, activeRoutes...)
}
}
return routes
}
// GetPeerRoutes returns a list of routes of a given peer
func (a *Account) GetPeerRoutes(peerPubKey string) []*route.Route {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
var routes []*route.Route
for _, r := range a.Routes {
if r.Peer == peerPubKey {
routes = append(routes, r)
continue
}
}
return routes
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
var routes []*route.Route
for _, r := range a.Routes {
if r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes
}
// GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerPubKey string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{})
for s, group := range a.Groups {
for _, peer := range group.Peers {
if peerPubKey == peer {
peerGroups[s] = struct{}{}
break
}
}
}
// Second, find all rules that have discovered source and destination groups
srcRulesMap := make(map[string]*Rule)
dstRulesMap := make(map[string]*Rule)
for _, rule := range a.Rules {
for _, g := range rule.Source {
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
srcRules = append(srcRules, rule)
srcRulesMap[rule.ID] = rule
}
}
for _, g := range rule.Destination {
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
dstRules = append(dstRules, rule)
dstRulesMap[rule.ID] = rule
}
}
}
return srcRules, dstRules
}
// GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*Peer {
var peers []*Peer
for _, peer := range a.Peers {
peers = append(peers, peer)
}
return peers
}
// UpdatePeer saves new or replaces existing peer
func (a *Account) UpdatePeer(update *Peer) {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
a.Peers[update.Key] = update
}
// DeletePeer deletes peer from the account cleaning up all the references
func (a *Account) DeletePeer(peerPubKey string) {
// TODO Peer.ID migration: we will need to replace search by Peer.ID here
// delete peer from groups
for _, g := range a.Groups {
for i, pk := range g.Peers {
if pk == peerPubKey {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
break
}
}
}
delete(a.Peers, peerPubKey)
a.Network.IncSerial()
}
// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found.
// It will return an object copy of the peer.
func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
for _, peer := range a.Peers {
if peer.Key == peerPubKey {
return peer.Copy(), nil
}
}
return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey)
}
// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, Errorf(UserNotFound, "user %s not found", userID)
}
return user, nil
}
// getPeerDNSLabels return the account's peers' dns labels
func (a *Account) getPeerDNSLabels() lookupMap {
existingLabels := make(lookupMap)
for _, peer := range a.Peers {
if peer.DNSLabel != "" {
existingLabels[peer.DNSLabel] = struct{}{}
}
}
return existingLabels
}
func (a *Account) Copy() *Account {
peers := map[string]*Peer{}
for id, peer := range a.Peers {
@@ -172,16 +325,19 @@ func (a *Account) Copy() *Account {
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
SetupKeys: setupKeys,
Network: a.Network.Copy(),
Peers: peers,
Users: users,
Groups: groups,
Rules: rules,
Routes: routes,
NameServerGroups: nsGroups,
Id: a.Id,
CreatedBy: a.CreatedBy,
Domain: a.Domain,
DomainCategory: a.DomainCategory,
IsDomainPrimaryAccount: a.IsDomainPrimaryAccount,
SetupKeys: setupKeys,
Network: a.Network.Copy(),
Peers: peers,
Users: users,
Groups: groups,
Rules: rules,
Routes: routes,
NameServerGroups: nsGroups,
}
}
@@ -195,36 +351,51 @@ func (a *Account) GetGroupAll() (*Group, error) {
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(
store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
) (*DefaultAccountManager, error) {
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
mux: sync.Mutex{},
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
}
allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
if am.singleAccountMode {
am.singleAccountModeDomain = singleAccountModeDomain
log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
} else {
log.Infof("single account mode disabled, accounts number %d", len(allAccounts))
}
// if account has not default group
// if account doesn't have a default group
// we create 'all' group and add all peers into it
// also we create default rule with source as destination
for _, account := range store.GetAllAccounts() {
for _, account := range allAccounts {
shouldSave := false
_, err := account.GetGroupAll()
if err != nil {
addAllGroup(account)
if err := store.SaveAccount(account); err != nil {
shouldSave = true
}
if shouldSave {
err = store.SaveAccount(account)
if err != nil {
return nil, err
}
}
}
gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
gocacheStore := cacheStore.NewGoCache(gocacheClient)
goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
goCacheStore := cacheStore.NewGoCache(goCacheClient)
am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](gocacheStore))
am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore))
if !isNil(am.idpManager) {
go func() {
@@ -278,32 +449,17 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil
}
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
return account, nil
}
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*Account, error) {
if accountId != "" {
return am.GetAccountById(accountId)
} else if userId != "" {
account, err := am.GetOrCreateAccountByUser(userId, domain)
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
if accountID != "" {
return am.Store.GetAccount(accountID)
} else if userID != "" {
account, err := am.GetOrCreateAccountByUser(userID, domain)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
}
err = am.addAccountIDToIDPAppMeta(userId, account)
err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil {
return nil, err
}
@@ -585,7 +741,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
}
if user.AppMetadata.WTPendingInvite {
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.Infof("redeeming invite for user %s account %s", userID, account.Id)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
@@ -604,6 +760,15 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
claims.Domain = am.singleAccountModeDomain
claims.DomainCategory = PrivateCategory
log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil {
return nil, err
@@ -634,15 +799,13 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId)
accountFromID, err := am.Store.GetAccount(claims.AccountId)
if err != nil {
return nil, err
}
@@ -654,8 +817,8 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
}
}
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireGlobalLock()
defer unlock()
// We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
@@ -664,7 +827,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
return nil, err
}
account, err := am.Store.GetUserAccount(claims.UserId)
account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil {
err = am.handleExistingUserAccount(account, domainAccount, claims)
if err != nil {
@@ -685,12 +848,13 @@ func isDomainValid(domain string) bool {
}
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
var res bool
_, err := am.Store.GetAccount(accountId)
_, err := am.Store.GetAccount(accountID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
res = false

View File

@@ -1,7 +1,11 @@
package server
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"net"
"reflect"
"sync"
"testing"
@@ -117,7 +121,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.GetAccountByUser(userId)
account, err = manager.Store.GetAccountByUser(userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@@ -298,7 +302,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs {
@@ -341,7 +345,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.GetAccountByUser(userId)
account, err = manager.Store.GetAccountByUser(userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@@ -397,7 +401,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountId(userId, "", "")
account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
if err != nil {
t.Fatal(err)
}
@@ -407,12 +411,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
accountId := account.Id
_, err = manager.GetAccountByUserOrAccountId("", accountId, "")
_, err = manager.GetAccountByUserOrAccountID("", accountId, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
}
_, err = manager.GetAccountByUserOrAccountId("", "", "")
_, err = manager.GetAccountByUserOrAccountID("", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
@@ -466,7 +470,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
// AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccountById(expectedId)
getAccount, err := manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -536,7 +540,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return
}
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -598,7 +602,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -676,7 +680,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer()
peer3 := getPeer()
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -844,7 +848,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return
}
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -874,7 +878,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user
}
userInfos, err := manager.GetUsersFromAccount(accountId)
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
if err != nil {
t.Fatal(err)
}
@@ -957,17 +961,257 @@ func TestAccountManager_UpdatePeerMeta(t *testing.T) {
assert.Equal(t, newMeta, p.Meta)
}
func TestAccount_GetPeerRules(t *testing.T) {
groups := map[string]*Group{
"group_1": {
ID: "group_1",
Name: "group_1",
Peers: []string{"peer-1", "peer-2"},
},
"group_2": {
ID: "group_2",
Name: "group_2",
Peers: []string{"peer-2", "peer-3"},
},
"group_3": {
ID: "group_3",
Name: "group_3",
Peers: []string{"peer-4"},
},
"group_4": {
ID: "group_4",
Name: "group_4",
Peers: []string{"peer-1"},
},
"group_5": {
ID: "group_5",
Name: "group_5",
Peers: []string{"peer-1"},
},
}
rules := map[string]*Rule{
"rule-1": {
ID: "rule-1",
Name: "rule-1",
Description: "rule-1",
Disabled: false,
Source: []string{"group_1", "group_5"},
Destination: []string{"group_2"},
Flow: 0,
},
"rule-2": {
ID: "rule-2",
Name: "rule-2",
Description: "rule-2",
Disabled: false,
Source: []string{"group_1"},
Destination: []string{"group_1"},
Flow: 0,
},
"rule-3": {
ID: "rule-3",
Name: "rule-3",
Description: "rule-3",
Disabled: false,
Source: []string{"group_3"},
Destination: []string{"group_3"},
Flow: 0,
},
}
account := &Account{
Groups: groups,
Rules: rules,
}
srcRules, dstRules := account.GetPeerRules("peer-1")
assert.Equal(t, 2, len(srcRules))
assert.Equal(t, 1, len(dstRules))
}
func TestFileStore_GetRoutesByPrefix(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
account := &Account{
Routes: map[string]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
"route-2": {
ID: "route-2",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
},
}
routes := account.GetRoutesByPrefix(prefix)
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-1")
assert.Contains(t, routeIDs, "route-2")
}
func TestAccount_GetPeersRoutes(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
account := &Account{
Peers: map[string]*Peer{
"peer-1": {Key: "peer-1"}, "peer-2": {Key: "peer-2"}, "peer-3": {Key: "peer-1"},
},
Routes: map[string]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
"route-2": {
ID: "route-2",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
},
}
routes := account.GetPeersRoutes([]*Peer{{Key: "peer-1"}, {Key: "peer-2"}, {Key: "non-existing-peer"}})
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-1")
assert.Contains(t, routeIDs, "route-2")
}
func TestAccount_Copy(t *testing.T) {
account := &Account{
Id: "account1",
CreatedBy: "tester",
Domain: "test.com",
DomainCategory: "public",
IsDomainPrimaryAccount: true,
SetupKeys: map[string]*SetupKey{
"setup1": {
Id: "setup1",
AutoGroups: []string{"group1"},
},
},
Network: &Network{
Id: "net1",
},
Peers: map[string]*Peer{
"peer1": {
Key: "key1",
},
},
Users: map[string]*User{
"user1": {
Id: "user1",
Role: UserRoleAdmin,
AutoGroups: []string{"group1"},
},
},
Groups: map[string]*Group{
"group1": {
ID: "group1",
},
},
Rules: map[string]*Rule{
"rule1": {
ID: "rule1",
},
},
Routes: map[string]*route.Route{
"route1": {
ID: "route1",
},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": {
ID: "nsGroup1",
},
},
}
err := hasNilField(account)
if err != nil {
t.Fatal(err)
}
accountCopy := account.Copy()
assert.Equal(t, account, accountCopy, "account copy returned a different value than expected")
}
// hasNilField validates pointers, maps and slices if they are nil
func hasNilField(x interface{}) error {
rv := reflect.ValueOf(x)
rv = rv.Elem()
for i := 0; i < rv.NumField(); i++ {
if f := rv.Field(i); f.IsValid() {
k := f.Kind()
switch k {
case reflect.Ptr:
if f.IsNil() {
return fmt.Errorf("field %s is nil", f.String())
}
case reflect.Map, reflect.Slice:
if f.Len() == 0 || f.IsNil() {
return fmt.Errorf("field %s is nil", f.String())
}
}
}
}
return nil
}
func createManager(t *testing.T) (*DefaultAccountManager, error) {
store, err := createStore(t)
if err != nil {
return nil, err
}
return BuildManager(store, NewPeersUpdateManager(), nil)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
}
func createStore(t *testing.T) (Store, error) {
dataDir := t.TempDir()
store, err := NewStore(dataDir)
store, err := NewFileStore(dataDir)
if err != nil {
return nil, err
}

152
management/server/dns.go Normal file
View File

@@ -0,0 +1,152 @@
package server
import (
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus"
"strconv"
)
type lookupMap map[string]struct{}
const defaultTTL = 300
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
for _, zone := range update.CustomZones {
protoZone := &proto.CustomZone{Domain: zone.Domain}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
}
for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
}
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
}
return customZone
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
groupList := make(lookupMap)
for groupID, group := range account.Groups {
for _, id := range group.Peers {
if id == peerID {
groupList[groupID] = struct{}{}
break
}
}
}
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range account.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
return peerNSGroups
}
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil {
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil {
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skiping", peer.Meta.Hostname, err)
continue
}
}
peer.DNSLabel = label
peerLabels[label] = struct{}{}
}
}
func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) {
label, err := nbdns.GetParsedDomainLabel(name)
if err != nil {
return "", err
}
uniqueLabel := getUniqueHostLabel(label, peerLabels)
if uniqueLabel == "" {
return "", fmt.Errorf("couldn't find a unique valid lavel for %s, parsed label %s", name, label)
}
return uniqueLabel, nil
}
// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999
func getUniqueHostLabel(name string, peerLabels lookupMap) string {
_, found := peerLabels[name]
if !found {
return name
}
for i := 1; i < 1000; i++ {
nameWithSuffix := name + "-" + strconv.Itoa(i)
_, found = peerLabels[nameWithSuffix]
if !found {
return nameWithSuffix
}
}
return ""
}

View File

@@ -11,6 +11,12 @@ const (
AccountNotFound ErrorType = iota
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed ErrorType = iota
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
UserNotFound ErrorType = iota
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied ErrorType = iota
)
// ErrorType is a type of the Error

View File

@@ -1,13 +1,12 @@
package server
import (
"fmt"
"github.com/netbirdio/netbird/route"
"net/netip"
log "github.com/sirupsen/logrus"
"os"
"path/filepath"
"strings"
"sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -21,29 +20,29 @@ const storeFileName = "store.json"
// FileStore represents an account storage backed by a file persisted to disk
type FileStore struct {
Accounts map[string]*Account
SetupKeyId2AccountId map[string]string `json:"-"`
PeerKeyId2AccountId map[string]string `json:"-"`
UserId2AccountId map[string]string `json:"-"`
PrivateDomain2AccountId map[string]string `json:"-"`
PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyID2RouteIDs map[string]map[string]struct{} `json:"-"`
AccountPrefix2RouteIDs map[string]map[string][]string `json:"-"`
SetupKeyID2AccountID map[string]string `json:"-"`
PeerKeyID2AccountID map[string]string `json:"-"`
UserID2AccountID map[string]string `json:"-"`
PrivateDomain2AccountID map[string]string `json:"-"`
InstallationID string
// mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"`
storeFile string `json:"-"`
// sync.Mutex indexed by accountID
accountLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
}
type StoredAccount struct{}
// NewStore restores a store from the file located in the datadir
func NewStore(dataDir string) (*FileStore, error) {
// NewFileStore restores a store from the file located in the datadir
func NewFileStore(dataDir string) (*FileStore, error) {
return restore(filepath.Join(dataDir, storeFileName))
}
// restore restores the state of the store from the file.
// restore the state of the store from the file.
// Creates a new empty store file if doesn't exist
func restore(file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) {
@@ -51,14 +50,11 @@ func restore(file string) (*FileStore, error) {
s := &FileStore{
Accounts: make(map[string]*Account),
mux: sync.Mutex{},
SetupKeyId2AccountId: make(map[string]string),
PeerKeyId2AccountId: make(map[string]string),
UserId2AccountId: make(map[string]string),
PrivateDomain2AccountId: make(map[string]string),
PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}),
PeerKeyID2RouteIDs: make(map[string]map[string]struct{}),
PeerKeyId2DstRulesId: make(map[string]map[string]struct{}),
AccountPrefix2RouteIDs: make(map[string]map[string][]string),
globalAccountLock: sync.Mutex{},
SetupKeyID2AccountID: make(map[string]string),
PeerKeyID2AccountID: make(map[string]string),
UserID2AccountID: make(map[string]string),
PrivateDomain2AccountID: make(map[string]string),
storeFile: file,
}
@@ -77,287 +73,113 @@ func restore(file string) (*FileStore, error) {
store := read.(*FileStore)
store.storeFile = file
store.SetupKeyId2AccountId = make(map[string]string)
store.PeerKeyId2AccountId = make(map[string]string)
store.UserId2AccountId = make(map[string]string)
store.PrivateDomain2AccountId = make(map[string]string)
store.PeerKeyId2SrcRulesId = make(map[string]map[string]struct{})
store.PeerKeyId2DstRulesId = make(map[string]map[string]struct{})
store.PeerKeyID2RouteIDs = make(map[string]map[string]struct{})
store.AccountPrefix2RouteIDs = make(map[string]map[string][]string)
store.SetupKeyID2AccountID = make(map[string]string)
store.PeerKeyID2AccountID = make(map[string]string)
store.UserID2AccountID = make(map[string]string)
store.PrivateDomain2AccountID = make(map[string]string)
for accountId, account := range store.Accounts {
for accountID, account := range store.Accounts {
for setupKeyId := range account.SetupKeys {
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId
}
for _, rule := range account.Rules {
for _, groupID := range rule.Source {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2SrcRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2SrcRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, groupID := range rule.Destination {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2DstRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2DstRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
store.SetupKeyID2AccountID[strings.ToUpper(setupKeyId)] = accountID
}
for _, peer := range account.Peers {
store.PeerKeyId2AccountId[peer.Key] = accountId
store.PeerKeyID2AccountID[peer.Key] = accountID
// reset all peers to status = Disconnected
if peer.Status != nil && peer.Status.Connected {
peer.Status.Connected = false
}
}
for _, user := range account.Users {
store.UserId2AccountId[user.Id] = accountId
store.UserID2AccountID[user.Id] = accountID
}
for _, user := range account.Users {
store.UserId2AccountId[user.Id] = accountId
}
for _, route := range account.Routes {
if route.Peer == "" {
continue
}
if store.PeerKeyID2RouteIDs[route.Peer] == nil {
store.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
}
store.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
if store.AccountPrefix2RouteIDs[account.Id] == nil {
store.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
}
if _, ok := store.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
}
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
route.ID,
)
store.UserID2AccountID[user.Id] = accountID
}
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
account.IsDomainPrimaryAccount {
store.PrivateDomain2AccountId[account.Domain] = accountId
store.PrivateDomain2AccountID[account.Domain] = accountID
}
// for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels()
if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(account, existingLabels)
}
}
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
err = store.persist(store.storeFile)
if err != nil {
return nil, err
}
return store, nil
}
// persist persists account data to a file
// persist account data to a file
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(file string) error {
return util.WriteJson(file, s)
}
// SavePeer saves updated peer
func (s *FileStore) SavePeer(accountId string, peer *Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
account, err := s.GetAccount(accountId)
if err != nil {
return err
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
}
// if it is new peer, add it to default 'All' group
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
ind := -1
for i, pid := range allGroup.Peers {
if pid == peer.Key {
ind = i
break
}
}
if ind < 0 {
allGroup.Peers = append(allGroup.Peers, peer.Key)
}
account.Peers[peer.Key] = peer
return s.persist(s.storeFile)
return unlock
}
// DeletePeer deletes peer from the Store
func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
// AcquireAccountLock acquires account lock and returns a function that releases the lock
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
peer := account.Peers[peerKey]
if peer == nil {
return nil, status.Errorf(codes.NotFound, "peer not found")
}
peerRoutes := s.PeerKeyID2RouteIDs[peerKey]
delete(account.Peers, peerKey)
delete(s.PeerKeyId2AccountId, peerKey)
delete(s.PeerKeyId2DstRulesId, peerKey)
delete(s.PeerKeyId2SrcRulesId, peerKey)
delete(s.PeerKeyID2RouteIDs, peerKey)
// cleanup groups
for _, g := range account.Groups {
var peers []string
for _, p := range g.Peers {
if p != peerKey {
peers = append(peers, p)
}
}
g.Peers = peers
}
for routeID := range peerRoutes {
account.Routes[routeID].Enabled = false
account.Routes[routeID].Peer = ""
}
err = s.persist(s.storeFile)
if err != nil {
return nil, err
}
return peer, nil
return unlock
}
// GetPeer returns a peer from a Store
func (s *FileStore) GetPeer(peerKey string) (*Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "peer not found")
}
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
if peer, ok := account.Peers[peerKey]; ok {
return peer, nil
}
return nil, status.Errorf(codes.NotFound, "peer not found")
}
// SaveAccount updates an existing account or adds a new one
func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
accountCopy := account.Copy()
// todo will override, handle existing keys
s.Accounts[account.Id] = account
s.Accounts[accountCopy.Id] = accountCopy
// todo check that account.Id and keyId are not exist already
// because if keyId exists for other accounts this can be bad
for keyId := range account.SetupKeys {
s.SetupKeyId2AccountId[strings.ToUpper(keyId)] = account.Id
for keyID := range accountCopy.SetupKeys {
s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id
}
// enforce peer to account index and delete peer to route indexes for rebuild
for _, peer := range account.Peers {
s.PeerKeyId2AccountId[peer.Key] = account.Id
delete(s.PeerKeyID2RouteIDs, peer.Key)
for _, peer := range accountCopy.Peers {
s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id
}
delete(s.AccountPrefix2RouteIDs, account.Id)
// remove all peers related to account from rules indexes
cleanIDs := make([]string, 0)
for key := range s.PeerKeyId2SrcRulesId {
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
cleanIDs = append(cleanIDs, key)
}
}
for _, key := range cleanIDs {
delete(s.PeerKeyId2SrcRulesId, key)
}
cleanIDs = cleanIDs[:0]
for key := range s.PeerKeyId2DstRulesId {
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
cleanIDs = append(cleanIDs, key)
}
}
for _, key := range cleanIDs {
delete(s.PeerKeyId2DstRulesId, key)
for _, user := range accountCopy.Users {
s.UserID2AccountID[user.Id] = accountCopy.Id
}
// rebuild rule indexes
for _, rule := range account.Rules {
for _, gid := range rule.Source {
g, ok := account.Groups[gid]
if !ok {
break
}
for _, pid := range g.Peers {
rules := s.PeerKeyId2SrcRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2SrcRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
for _, gid := range rule.Destination {
g, ok := account.Groups[gid]
if !ok {
break
}
for _, pid := range g.Peers {
rules := s.PeerKeyId2DstRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2DstRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, route := range account.Routes {
if route.Peer == "" {
continue
}
if s.PeerKeyID2RouteIDs[route.Peer] == nil {
s.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
}
s.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
if s.AccountPrefix2RouteIDs[account.Id] == nil {
s.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
}
if _, ok := s.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
}
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
route.ID,
)
}
for _, user := range account.Users {
s.UserId2AccountId[user.Id] = account.Id
}
if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
s.PrivateDomain2AccountId[account.Domain] = account.Id
if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
}
return s.persist(s.storeFile)
@@ -365,53 +187,41 @@ func (s *FileStore) SaveAccount(account *Account) error {
// GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
if !accountIdFound {
s.mux.Lock()
defer s.mux.Unlock()
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
if !accountIDFound {
return nil, status.Errorf(
codes.NotFound,
"provided domain is not registered or is not private",
"account not found: provided domain is not registered or is not private",
)
}
account, err := s.GetAccount(accountId)
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account, nil
return account.Copy(), nil
}
// GetAccountBySetupKey returns account by setup key id
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists")
}
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
return account, nil
}
// GetAccountPeers returns account peers
func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountId)
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists")
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
var peers []*Peer
for _, peer := range account.Peers {
peers = append(peers, peer)
}
return peers, nil
return account.Copy(), nil
}
// GetAllAccounts returns all accounts
@@ -425,9 +235,9 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
return all
}
// GetAccount returns an account for id
func (s *FileStore) GetAccount(accountId string) (*Account, error) {
account, accountFound := s.Accounts[accountId]
// getAccount returns a reference to the Account. Should not return a copy.
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, accountFound := s.Accounts[accountID]
if !accountFound {
return nil, status.Errorf(codes.NotFound, "account not found")
}
@@ -435,139 +245,53 @@ func (s *FileStore) GetAccount(accountId string) (*Account, error) {
return account, nil
}
// GetUserAccount returns a user account
func (s *FileStore) GetUserAccount(userId string) (*Account, error) {
// GetAccount returns an account for ID
func (s *FileStore) GetAccount(accountID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountId, accountIdFound := s.UserId2AccountId[userId]
if !accountIdFound {
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Copy(), nil
}
// GetAccountByUser returns a user account
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, accountIDFound := s.UserID2AccountID[userID]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found")
}
return s.GetAccount(accountId)
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Copy(), nil
}
func (s *FileStore) getPeerAccount(peerKey string) (*Account, error) {
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey]
if !accountIdFound {
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
}
return s.GetAccount(accountId)
}
// GetPeerAccount returns user account if exists
func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
return s.getPeerAccount(peerKey)
}
// GetPeerSrcRules return list of source rules for peer
func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountId)
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey]
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
}
// GetPeerDstRules return list of destination rules for peer
func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey]
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
}
// GetPeerRoutes return list of routes for peer
func (s *FileStore) GetPeerRoutes(peerKey string) ([]*route.Route, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getPeerAccount(peerKey)
if err != nil {
return nil, err
}
var routes []*route.Route
routeIDs, ok := s.PeerKeyID2RouteIDs[peerKey]
if !ok {
return routes, nil
}
for id := range routeIDs {
route, found := account.Routes[id]
if found {
routes = append(routes, route)
}
}
return routes, nil
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (s *FileStore) GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountID)
if err != nil {
return nil, err
}
routeIDs, ok := s.AccountPrefix2RouteIDs[accountID][prefix.String()]
if !ok {
return nil, status.Errorf(codes.NotFound, "no routes for prefix: %v", prefix.String())
}
var routes []*route.Route
for _, id := range routeIDs {
route, found := account.Routes[id]
if found {
routes = append(routes, route)
}
}
return routes, nil
return account.Copy(), nil
}
// GetInstallationID returns the installation ID from the store
@@ -576,11 +300,42 @@ func (s *FileStore) GetInstallationID() string {
}
// SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(id string) error {
func (s *FileStore) SaveInstallationID(ID string) error {
s.mux.Lock()
defer s.mux.Unlock()
s.InstallationID = id
s.InstallationID = ID
return s.persist(s.storeFile)
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerStatus) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
peer := account.Peers[peerKey]
if peer == nil {
return status.Errorf(codes.NotFound, "peer %s not found", peerKey)
}
peer.Status = &peerStatus
return nil
}
// Close the FileStore persisting data to disk
func (s *FileStore) Close() error {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("closing FileStore")
return s.persist(s.storeFile)
}

View File

@@ -2,6 +2,7 @@ package server
import (
"github.com/netbirdio/netbird/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"path/filepath"
@@ -9,6 +10,10 @@ import (
"time"
)
type accounts struct {
Accounts map[string]*Account
}
func TestNewStore(t *testing.T) {
store := newStore(t)
@@ -16,16 +21,16 @@ func TestNewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
if store.SetupKeyId2AccountId == nil || len(store.SetupKeyId2AccountId) != 0 {
t.Errorf("expected to create a new empty SetupKeyId2AccountId map when creating a new FileStore")
if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 {
t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore")
}
if store.PeerKeyId2AccountId == nil || len(store.PeerKeyId2AccountId) != 0 {
t.Errorf("expected to create a new empty PeerKeyId2AccountId map when creating a new FileStore")
if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 {
t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore")
}
if store.UserId2AccountId == nil || len(store.UserId2AccountId) != 0 {
t.Errorf("expected to create a new empty UserId2AccountId map when creating a new FileStore")
if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 {
t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore")
}
}
@@ -55,16 +60,16 @@ func TestSaveAccount(t *testing.T) {
t.Errorf("expecting Account to be stored after SaveAccount()")
}
if store.PeerKeyId2AccountId["peerkey"] == "" {
t.Errorf("expecting PeerKeyId2AccountId index updated after SaveAccount()")
if store.PeerKeyID2AccountID["peerkey"] == "" {
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()")
}
if store.UserId2AccountId["testuser"] == "" {
t.Errorf("expecting UserId2AccountId index updated after SaveAccount()")
if store.UserID2AccountID["testuser"] == "" {
t.Errorf("expecting UserID2AccountID index updated after SaveAccount()")
}
if store.SetupKeyId2AccountId[setupKey.Key] == "" {
t.Errorf("expecting SetupKeyId2AccountId index updated after SaveAccount()")
if store.SetupKeyID2AccountID[setupKey.Key] == "" {
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()")
}
}
@@ -88,7 +93,7 @@ func TestStore(t *testing.T) {
return
}
restored, err := NewStore(store.storeFile)
restored, err := NewFileStore(store.storeFile)
if err != nil {
return
}
@@ -124,7 +129,7 @@ func TestRestore(t *testing.T) {
t.Fatal(err)
}
store, err := NewStore(storeDir)
store, err := NewFileStore(storeDir)
if err != nil {
return
}
@@ -141,11 +146,11 @@ func TestRestore(t *testing.T) {
require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.Len(t, store.UserId2AccountId, 2, "failed to restore a FileStore wrong UserId2AccountId mapping length")
require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length")
require.Len(t, store.SetupKeyId2AccountId, 1, "failed to restore a FileStore wrong SetupKeyId2AccountId mapping length")
require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length")
require.Len(t, store.PrivateDomain2AccountId, 1, "failed to restore a FileStore wrong PrivateDomain2AccountId mapping length")
require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length")
}
func TestGetAccountByPrivateDomain(t *testing.T) {
@@ -156,7 +161,7 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
t.Fatal(err)
}
store, err := NewStore(storeDir)
store, err := NewFileStore(storeDir)
if err != nil {
return
}
@@ -171,8 +176,101 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
require.Error(t, err, "should return error on domain lookup")
}
func TestFileStore_GetAccount(t *testing.T) {
storeDir := t.TempDir()
storeFile := filepath.Join(storeDir, "store.json")
err := util.CopyFileContents("testdata/store.json", storeFile)
if err != nil {
t.Fatal(err)
}
accounts := &accounts{}
_, err = util.ReadJson(storeFile, accounts)
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir)
if err != nil {
t.Fatal(err)
}
expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
if expected == nil {
t.Fatalf("expected account doesn't exist")
}
account, err := store.GetAccount(expected.Id)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
assert.Equal(t, expected.DomainCategory, account.DomainCategory)
assert.Equal(t, expected.Domain, account.Domain)
assert.Equal(t, expected.CreatedBy, account.CreatedBy)
assert.Equal(t, expected.Network.Id, account.Network.Id)
assert.Len(t, account.Peers, len(expected.Peers))
assert.Len(t, account.Users, len(expected.Users))
assert.Len(t, account.SetupKeys, len(expected.SetupKeys))
assert.Len(t, account.Rules, len(expected.Rules))
assert.Len(t, account.Routes, len(expected.Routes))
assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
}
func TestFileStore_SavePeerStatus(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir)
if err != nil {
return
}
account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
if err != nil {
t.Fatal(err)
}
// save status of non-existing peer
newStatus := PeerStatus{Connected: true, LastSeen: time.Now()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
// save new status of existing peer
account.Peers["testpeer"] = &Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{},
Name: "peer name",
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
}
err = store.SaveAccount(account)
if err != nil {
t.Fatal(err)
}
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
if err != nil {
t.Fatal(err)
}
account, err = store.getAccount(account.Id)
if err != nil {
t.Fatal(err)
}
actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual)
}
func newStore(t *testing.T) *FileStore {
store, err := NewStore(t.TempDir())
store, err := NewFileStore(t.TempDir())
if err != nil {
t.Errorf("failed creating a new store")
}

View File

@@ -47,8 +47,9 @@ func (g *Group) Copy() *Group {
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -65,8 +66,9 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -86,8 +88,9 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
// UpdateGroup updates a group using a list of operations
func (am *DefaultAccountManager) UpdateGroup(accountID string,
groupID string, operations []GroupUpdateOperation) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -135,8 +138,9 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -155,8 +159,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -173,8 +178,9 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -207,8 +213,9 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -235,8 +242,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
// GroupListPeers returns list of the peers from the group
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {

View File

@@ -3,7 +3,7 @@ package server
import (
"context"
"fmt"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/telemetry"
gPeer "google.golang.org/grpc/peer"
"strings"
"time"
@@ -30,10 +30,12 @@ type GRPCServer struct {
config *Config
turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware
appMetrics telemetry.AppMetrics
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager) (*GRPCServer, error) {
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager,
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -53,6 +55,16 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
log.Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
})
if err != nil {
return nil, err
}
}
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@@ -61,11 +73,15 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware,
appMetrics: appMetrics,
}, nil
}
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
}
now := time.Now().Add(24 * time.Hour)
secs := int64(now.Second())
nanos := int32(now.Nanosecond())
@@ -80,6 +96,9 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
p, ok := gRPCPeer.FromContext(srv.Context())
if ok {
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
@@ -233,17 +252,14 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
}
// notify other peers of our registration
for _, remotePeer := range networkMap.Peers {
// exclude notified peer and add ourselves
peersToSend := []*Peer{peer}
for _, p := range networkMap.Peers {
if remotePeer.Key != p.Key {
peersToSend = append(peersToSend, p)
}
remotePeerNetworkMap, err := s.accountManager.GetNetworkMap(remotePeer.Key)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
}
update := toSyncResponse(s.config, remotePeer, peersToSend, networkMap.Routes, nil, networkMap.Network.CurrentSerial(), networkMap.Network)
update := toSyncResponse(s.config, remotePeer, nil, remotePeerNetworkMap)
err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update})
if err != nil {
// todo rethink if we should keep this return
@@ -259,6 +275,9 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
p, ok := gRPCPeer.FromContext(ctx)
if ok {
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
@@ -370,6 +389,9 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
}
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
@@ -428,14 +450,16 @@ func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
return remotePeers
}
func toSyncResponse(config *Config, peer *Peer, peers []*Peer, routes []*route.Route, turnCredentials *TURNCredentials, serial uint64, network *Network) *proto.SyncResponse {
func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, network)
pConfig := toPeerConfig(peer, networkMap.Network)
remotePeers := toRemotePeerConfig(peers)
remotePeers := toRemotePeerConfig(networkMap.Peers)
routesUpdate := toProtocolRoutes(routes)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
@@ -443,11 +467,12 @@ func toSyncResponse(config *Config, peer *Peer, peers []*Peer, routes []*route.R
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{
Serial: serial,
Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
},
}
}
@@ -473,7 +498,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.config, peer, networkMap.Peers, networkMap.Routes, turnCredentials, networkMap.Network.CurrentSerial(), networkMap.Network)
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {

View File

@@ -136,6 +136,9 @@ components:
ui_version:
description: Peer's desktop UI version
type: string
dns_label:
description: Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
type: string
required:
- ip
- connected
@@ -145,6 +148,7 @@ components:
- groups
- ssh_enabled
- hostname
- dns_label
SetupKey:
type: object
properties:
@@ -444,12 +448,24 @@ components:
type: array
items:
type: string
primary:
description: Nameserver group primary status
type: boolean
domains:
description: Nameserver group domain list
type: array
items:
type: string
minLength: 1
maxLength: 255
required:
- name
- description
- nameservers
- enabled
- groups
- primary
- domains
NameserverGroup:
allOf:
- type: object
@@ -468,7 +484,7 @@ components:
path:
description: Nameserver group field to update in form /<field>
type: string
enum: [ "name","description","enabled","groups","nameservers" ]
enum: [ "name", "description", "enabled", "groups", "nameservers", "primary", "domains" ]
required:
- path

View File

@@ -39,10 +39,12 @@ const (
// Defines values for NameserverGroupPatchOperationPath.
const (
NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description"
NameserverGroupPatchOperationPathDomains NameserverGroupPatchOperationPath = "domains"
NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled"
NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups"
NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name"
NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers"
NameserverGroupPatchOperationPathPrimary NameserverGroupPatchOperationPath = "primary"
)
// Defines values for PatchMinimumOp.
@@ -159,6 +161,9 @@ type NameserverGroup struct {
// Description Nameserver group description
Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
@@ -173,6 +178,9 @@ type NameserverGroup struct {
// Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
}
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
@@ -198,6 +206,9 @@ type NameserverGroupRequest struct {
// Description Nameserver group description
Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
@@ -209,6 +220,9 @@ type NameserverGroupRequest struct {
// Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
}
// PatchMinimum defines model for PatchMinimum.
@@ -228,6 +242,9 @@ type Peer struct {
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`

View File

@@ -33,7 +33,7 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -50,7 +50,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
// UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -111,7 +111,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
// PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -226,7 +226,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
// CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -260,7 +260,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
// DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -295,7 +295,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
// GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return

View File

@@ -21,11 +21,11 @@ import (
)
var TestPeers = map[string]*server.Peer{
"A": &server.Peer{Key: "A", IP: net.ParseIP("100.100.100.100")},
"B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")},
"A": {Key: "A", IP: net.ParseIP("100.100.100.100")},
"B": {Key: "B", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(groups ...*server.Group) *Groups {
func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
return &Groups{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID string, group *server.Group) error {
@@ -72,6 +72,9 @@ func initGroupTestData(groups ...*server.Group) *Groups {
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
@@ -120,7 +123,8 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
p := initGroupTestData(group)
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -270,7 +274,8 @@ func TestWriteGroup(t *testing.T) {
},
}
p := initGroupTestData()
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@@ -4,12 +4,14 @@ import (
"github.com/gorilla/mux"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/cors"
"net/http"
)
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string) (http.Handler, error) {
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string,
appMetrics telemetry.AppMetrics) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware(
authIssuer,
authAudience,
@@ -21,12 +23,15 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
corsMiddleware := cors.AllowAll()
acMiddleware := middleware.NewAccessControll(
acMiddleware := middleware.NewAccessControl(
authAudience,
accountManager.IsUserAdmin)
apiHandler := mux.NewRouter()
apiHandler.Use(corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
apiHandler := rootRouter.PathPrefix("/api").Subrouter()
apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
groupsHandler := NewGroups(accountManager, authAudience)
rulesHandler := NewRules(accountManager, authAudience)
@@ -36,46 +41,67 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
routesHandler := NewRoutes(accountManager, authAudience)
nameserversHandler := NewNameservers(accountManager, authAudience)
apiHandler.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS")
return apiHandler, nil
err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err
}
for _, method := range methods {
template, err := route.GetPathTemplate()
if err != nil {
return err
}
err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return rootRouter, nil
}

View File

@@ -9,24 +9,25 @@ import (
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControll middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControll struct {
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
jwtExtractor jwtclaims.ClaimsExtractor
isUserAdmin IsUserAdminFunc
audience string
}
// NewAccessControll instance constructor
func NewAccessControll(audience string, isUserAdmin IsUserAdminFunc) *AccessControll {
return &AccessControll{
// NewAccessControl instance constructor
func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl {
return &AccessControl{
isUserAdmin: isUserAdmin,
audience: audience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
}
}
// Handler method of the middleware which forbinneds all modify requests for non admin users
func (a *AccessControll) Handler(h http.Handler) http.Handler {
// Handler method of the middleware which forbids all modify requests for non admin users
// It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience)
@@ -34,7 +35,6 @@ func (a *AccessControll) Handler(h http.Handler) http.Handler {
if err != nil {
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized)
return
}
if !ok {

View File

@@ -186,7 +186,7 @@ func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Reque
validatedToken, err := m.ValidateAndParse(token)
if err != nil {
m.Options.ErrorHandler(w, r, "The token isn't valid")
m.Options.ErrorHandler(w, r, err.Error())
return err
}

View File

@@ -30,7 +30,7 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
// GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -53,7 +53,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
// CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -71,7 +71,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Enabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
if err != nil {
toHTTPError(err, w)
return
@@ -84,7 +84,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -113,6 +113,8 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
ID: nsGroupID,
Name: req.Name,
Description: req.Description,
Primary: req.Primary,
Domains: req.Domains,
NameServers: nsList,
Groups: req.Groups,
Enabled: req.Enabled,
@@ -131,7 +133,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -168,6 +170,16 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
Type: server.UpdateNameServerGroupDescription,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathPrimary:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupPrimary,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathDomains:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupDomains,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathNameservers:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupNameServers,
@@ -202,7 +214,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
// DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -226,7 +238,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -279,6 +291,8 @@ func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.Namese
Id: serverNSGroup.ID,
Name: serverNSGroup.Name,
Description: serverNSGroup.Description,
Primary: serverNSGroup.Primary,
Domains: serverNSGroup.Domains,
Groups: serverNSGroup.Groups,
Nameservers: nsList,
Enabled: serverNSGroup.Enabled,

View File

@@ -29,12 +29,16 @@ const (
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: "super",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -60,7 +64,7 @@ func initNameserversTestData() *Nameservers {
}
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: name,
@@ -68,6 +72,8 @@ func initNameserversTestData() *Nameservers {
NameServers: nameServerList,
Groups: groups,
Enabled: enabled,
Primary: primary,
Domains: domains,
}, nil
},
DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error {
@@ -150,7 +156,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
@@ -166,6 +172,7 @@ func TestNameserversHandlers(t *testing.T) {
},
Groups: []string{"group"},
Enabled: true,
Primary: true,
},
},
{
@@ -173,7 +180,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
@@ -182,7 +189,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
@@ -198,6 +205,7 @@ func TestNameserversHandlers(t *testing.T) {
},
Groups: []string{"group"},
Enabled: true,
Primary: true,
},
},
{
@@ -205,7 +213,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
@@ -214,7 +222,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
@@ -232,6 +240,7 @@ func TestNameserversHandlers(t *testing.T) {
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
Groups: baseExistingNSGroup.Groups,
Enabled: baseExistingNSGroup.Enabled,
Primary: baseExistingNSGroup.Primary,
},
},
{

View File

@@ -56,7 +56,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -95,15 +95,20 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
return
}
respBody := []*api.Peer{}
for _, peer := range account.Peers {
for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account))
}
writeJSONObject(w, respBody)
@@ -147,5 +152,6 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Hostname: peer.Meta.Hostname,
UserId: &peer.UserID,
UiVersion: &peer.Meta.UIVersion,
DnsLabel: peer.DNSLabel,
}
}

View File

@@ -16,15 +16,21 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initTestMetaData(peer ...*server.Peer) *Peers {
func initTestMetaData(peers ...*server.Peer) *Peers {
return &Peers{
accountManager: &mock_server.MockAccountManager{
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*server.Peer{
"test_peer": peer[0],
"test_peer": peers[0],
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}, nil
},

View File

@@ -33,16 +33,24 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
routes, err := h.accountManager.ListRoutes(account.Id)
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
@@ -56,7 +64,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
// CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -104,7 +112,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -117,7 +125,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID)
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
return
@@ -177,7 +185,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
// PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -190,7 +198,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID)
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
log.Error(err)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
@@ -334,7 +342,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -363,7 +371,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
// GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -375,7 +383,7 @@ func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID)
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
http.Error(w, "route not found", http.StatusNotFound)
return

View File

@@ -51,12 +51,15 @@ var testingAccount = &server.Account{
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
func initRoutesTestData() *Routes {
return &Routes{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID string) (*route.Route, error) {
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
if routeID == existingRouteID {
return baseExistingRoute, nil
}

View File

@@ -31,15 +31,29 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
rules := []*api.Rule{}
for _, r := range account.Rules {
for _, r := range accountRules {
rules = append(rules, toRuleResponse(account, r))
}
@@ -48,7 +62,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -118,7 +132,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
// PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -275,7 +289,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
// CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -332,7 +346,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -356,7 +370,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
// GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -370,7 +384,7 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
return
}
rule, err := h.accountManager.GetRule(account.Id, ruleID)
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
if err != nil {
http.Error(w, "rule not found", http.StatusNotFound)
return

View File

@@ -28,7 +28,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}
return nil
},
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
@@ -75,6 +75,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
"F": {ID: "F"},
"G": {ID: "G"},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}, nil
},
},

View File

@@ -31,7 +31,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -81,7 +81,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -95,7 +95,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
return
}
key, err := h.accountManager.GetSetupKey(account.Id, keyID)
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
@@ -112,7 +112,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -168,14 +168,14 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
setupKeys, err := h.accountManager.ListSetupKeys(account.Id)
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@@ -28,13 +28,17 @@ const (
notFoundSetupKeyID = "notFoundSetupKeyID"
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
@@ -49,7 +53,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) {
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
@@ -67,7 +71,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) {
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},
@@ -75,7 +79,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: user.Id,
Domain: "hotmail.com",
AccountId: testAccountID,
}
@@ -88,6 +92,8 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
@@ -153,7 +159,7 @@ func TestSetupKeysHandlers(t *testing.T) {
},
}
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@@ -34,7 +34,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -87,7 +87,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
http.Error(w, "", http.StatusNotFound)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
}
@@ -132,12 +132,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
data, err := h.accountManager.GetUsersFromAccount(account.Id)
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@@ -27,7 +27,7 @@ func initUsers(user ...*server.User) *UserHandler {
Users: users,
}, nil
},
GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
for _, v := range user {
users = append(users, &server.UserInfo{
@@ -44,7 +44,7 @@ func initUsers(user ...*server.User) *UserHandler {
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: "1",
Domain: "hotmail.com",
AccountId: "test_id",
}

View File

@@ -56,16 +56,22 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, error) {
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(jwtClaims)
account, err := accountManager.GetAccountFromToken(claims)
if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
}
return account, nil
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
}
return account, user, nil
}
func toHTTPError(err error, w http.ResponseWriter) {

View File

@@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"io"
"net/http"
"net/url"
@@ -24,6 +25,7 @@ type Auth0Manager struct {
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// Auth0ClientConfig auth0 manager client configurations
@@ -51,6 +53,7 @@ type Auth0Credentials struct {
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// createUserRequest is a user create request
@@ -106,7 +109,7 @@ type auth0Profile struct {
}
// NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) {
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -134,12 +137,15 @@ func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) {
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &Auth0Manager{
authIssuer: config.AuthIssuer,
credentials: credentials,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}, nil
}
@@ -170,6 +176,9 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
res, err = c.httpClient.Do(req)
if err != nil {
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountRequestError()
}
return res, err
}
@@ -214,6 +223,10 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
c.mux.Lock()
defer c.mux.Unlock()
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountAuthenticate()
}
// If jwtToken has an expires time and we have enough time to do a request return immediately
if c.jwtStillValid() {
return c.jwtToken, nil
@@ -287,9 +300,16 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
@@ -342,9 +362,16 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
@@ -398,9 +425,16 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
defer func() {
err = res.Body.Close()
if err != nil {
@@ -416,12 +450,13 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
}
func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) {
invite := true
req := &createUserRequest{
Email: email,
Name: name,
AppMeta: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: true,
WTPendingInvite: &invite,
},
Connection: "Username-Password-Authentication",
Password: GeneratePassword(8, 1, 1, 1),
@@ -502,6 +537,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
@@ -512,6 +550,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
}
}()
if jobResp.StatusCode != 200 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
}
@@ -530,6 +571,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
}
if exportJobResp.ID == "" {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
}
@@ -556,12 +600,16 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
if err != nil {
return nil, err
}
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + email
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
userResp := []*UserData{}
err = am.helper.Unmarshal(body, &userResp)
@@ -585,9 +633,16 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
@@ -598,6 +653,9 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
@@ -698,7 +756,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
Email: profile.Email,
AppMetadata: AppMetadata{
WTAccountID: profile.AccountID,
WTPendingInvite: profile.PendingInvite,
WTPendingInvite: &profile.PendingInvite,
},
})
}
@@ -729,13 +787,12 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
log.Errorf("error while closing body for url %s: %v", url, err)
}
}()
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
return body, nil
}

View File

@@ -3,6 +3,7 @@ package idp
import (
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require"
"io"
"net/http"
@@ -340,7 +341,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
@@ -363,7 +364,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
@@ -371,7 +372,23 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
@@ -459,7 +476,7 @@ func TestNewAuth0Manager(t *testing.T) {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuth0Manager(testCase.inputConfig)
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}

View File

@@ -2,6 +2,7 @@ package idp
import (
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"net/http"
"strings"
"time"
@@ -51,7 +52,7 @@ type AppMetadata struct {
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
// maps to wt_account_id when json.marshal
WTAccountID string `json:"wt_account_id,omitempty"`
WTPendingInvite bool `json:"wt_pending_invite"`
WTPendingInvite *bool `json:"wt_pending_invite"`
}
// JWTToken a JWT object that holds information of a token
@@ -64,12 +65,12 @@ type JWTToken struct {
}
// NewManager returns a new idp manager based on the configuration that it receives
func NewManager(config Config) (Manager, error) {
func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
switch strings.ToLower(config.ManagerType) {
case "none", "":
return nil, nil
case "auth0":
return NewAuth0Manager(config.Auth0ClientCredentials)
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}

View File

@@ -398,17 +398,17 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro
return nil, err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := NewStore(config.Datadir)
store, err := NewFileStore(config.Datadir)
if err != nil {
return nil, err
}
peersUpdateManager := NewPeersUpdateManager()
accountManager, err := BuildManager(store, peersUpdateManager, nil)
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil {
return nil, err
}
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil {
return nil, err
}

View File

@@ -488,17 +488,17 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer()
store, err := server.NewStore(config.Datadir)
store, err := server.NewFileStore(config.Datadir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {

View File

@@ -17,7 +17,7 @@ import (
const (
// PayloadEvent identifies an event type
PayloadEvent = "self-hosted stats"
// payloadEndpoint metrics endpoint to send anonymous data
// payloadEndpoint metrics defaultEndpoint to send anonymous data
payloadEndpoint = "https://metrics.netbird.io"
// defaultPushInterval default interval to push metrics
defaultPushInterval = 24 * time.Hour
@@ -188,13 +188,17 @@ func (w *Worker) generateProperties() properties {
userPeers++
}
_, connected := connections[peer.Key]
if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++
}
osKey := strings.ToLower(fmt.Sprintf("peer_os_%s", peer.Meta.GoOS))
osCount := osPeers[osKey]
osPeers[osKey] = osCount + 1
_, connected := connections[peer.Key]
if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++
osActiveKey := osKey + "_active"
osActiveCount := osPeers[osActiveKey]
osPeers[osActiveKey] = osActiveCount + 1
}
}
}
@@ -279,5 +283,4 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
req.Header.Add("content-type", "application/json")
return req, nil
}

View File

@@ -1,13 +0,0 @@
## Migration from Store v2 to Store v2
Previously Account.Id was an Auth0 user id.
Conversion moves user id to Account.CreatedBy and generates a new Account.Id using xid.
It also adds a User with id = old Account.Id with a role Admin.
To start a conversion simply run the command below providing your current Wiretrustee Management datadir (where store.json file is located)
and a new data directory location (where a converted store.js will be stored):
```shell
./migration --oldDir /var/wiretrustee/datadir --newDir /var/wiretrustee/newdatadir/
```
Afterwards you can run the Management service providing ```/var/wiretrustee/newdatadir/ ``` as a datadir.

View File

@@ -1,56 +0,0 @@
package main
import (
"flag"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/rs/xid"
)
func main() {
oldDir := flag.String("oldDir", "old store directory", "/var/wiretrustee/datadir")
newDir := flag.String("newDir", "new store directory", "/var/wiretrustee/newdatadir")
flag.Parse()
oldStore, err := server.NewStore(*oldDir)
if err != nil {
panic(err)
}
newStore, err := server.NewStore(*newDir)
if err != nil {
panic(err)
}
err = Convert(oldStore, newStore)
if err != nil {
panic(err)
}
fmt.Println("successfully converted")
}
// Convert converts old store ato a new store
// Previously Account.Id was an Auth0 user id
// Conversion moved user id to Account.CreatedBy and generated a new Account.Id using xid
// It also adds a User with id = old Account.Id with a role Admin
func Convert(oldStore *server.FileStore, newStore *server.FileStore) error {
for _, account := range oldStore.Accounts {
accountCopy := account.Copy()
accountCopy.Id = xid.New().String()
accountCopy.CreatedBy = account.Id
accountCopy.Users[account.Id] = &server.User{
Id: account.Id,
Role: server.UserRoleAdmin,
}
err := newStore.SaveAccount(accountCopy)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,76 +0,0 @@
package main
import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
"path/filepath"
"testing"
)
func TestConvertAccounts(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("../testdata/storev1.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := server.NewStore(storeDir)
if err != nil {
t.Fatal(err)
}
convertedStore, err := server.NewStore(filepath.Join(storeDir, "converted"))
if err != nil {
t.Fatal(err)
}
err = Convert(store, convertedStore)
if err != nil {
t.Fatal(err)
}
if len(store.Accounts) != len(convertedStore.Accounts) {
t.Errorf("expecting the same number of accounts after conversion")
}
for _, account := range store.Accounts {
convertedAccount, err := convertedStore.GetUserAccount(account.Id)
if err != nil || convertedAccount == nil {
t.Errorf("expecting Account %s to be converted", account.Id)
return
}
if convertedAccount.CreatedBy != account.Id {
t.Errorf("expecting converted Account.CreatedBy field to be equal to the old Account.Id")
return
}
if convertedAccount.Id == account.Id {
t.Errorf("expecting converted Account.Id to be different from Account.Id")
return
}
if len(convertedAccount.Users) != 1 {
t.Errorf("expecting converted Account.Users to be of size 1")
return
}
user := convertedAccount.Users[account.Id]
if user == nil {
t.Errorf("expecting to find a user in converted Account.Users")
return
}
if user.Role != server.UserRoleAdmin {
t.Errorf("expecting to find a user in converted Account.Users with a role Admin")
return
}
for peerId := range account.Peers {
convertedPeer := convertedAccount.Peers[peerId]
if convertedPeer == nil {
t.Errorf("expecting Account Peer of StoreV1 to be found in StoreV2")
return
}
}
}
}

View File

@@ -14,14 +14,13 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error)
GetPeerFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
@@ -35,26 +34,26 @@ type MockAccountManager struct {
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
SaveRuleFunc func(accountID string, rule *server.Rule) error
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
DeleteRuleFunc func(accountID, ruleID string) error
ListRulesFunc func(accountID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
ListRulesFunc func(accountID, userID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
GetRouteFunc func(accountID, routeID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID string) ([]*route.Route, error)
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
@@ -64,13 +63,21 @@ type MockAccountManager struct {
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.UserInfo, error) {
func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(accountID)
return am.GetUsersFromAccountFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
}
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountId, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser(
userId, domain string,
@@ -106,16 +113,8 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) {
if am.GetAccountByIdFunc != nil {
return am.GetAccountByIdFunc(accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById is not implemented")
}
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountId(
func (am *MockAccountManager) GetAccountByUserOrAccountID(
userId, accountId, domain string,
) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil {
@@ -123,7 +122,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
}
return nil, status.Errorf(
codes.Unimplemented,
"method GetAccountByUserOrAccountId is not implemented",
"method GetAccountByUserOrAccountID is not implemented",
)
}
@@ -151,26 +150,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// RenamePeer mock implementation of RenamePeer from server.AccountManager interface
func (am *MockAccountManager) RenamePeer(
accountId string,
peerKey string,
newName string,
) (*server.Peer, error) {
if am.RenamePeerFunc != nil {
return am.RenamePeerFunc(accountId, peerKey, newName)
}
return nil, status.Errorf(codes.Unimplemented, "method RenamePeer is not implemented")
}
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountId, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
if am.GetPeerByIPFunc != nil {
@@ -272,9 +251,9 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
}
// GetRule mock implementation of GetRule from server.AccountManager interface
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) {
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID)
return am.GetRuleFunc(accountID, ruleID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented")
}
@@ -304,9 +283,9 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
}
// ListRules mock implementation of ListRules from server.AccountManager interface
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) {
func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) {
if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID)
return am.ListRulesFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented")
}
@@ -345,16 +324,16 @@ func (am *MockAccountManager) UpdatePeer(accountID string, peer *server.Peer) (*
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
if am.GetRouteFunc != nil {
if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
// GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) {
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID)
return am.GetRouteFunc(accountID, routeID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented")
}
@@ -384,9 +363,9 @@ func (am *MockAccountManager) DeleteRoute(accountID, routeID string) error {
}
// ListRoutes mock implementation of ListRoutes from server.AccountManager interface
func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, error) {
func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
if am.ListRoutesFunc != nil {
return am.ListRoutesFunc(accountID)
return am.ListRoutesFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
}
@@ -401,18 +380,18 @@ func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKe
}
// GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) {
func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, keyID)
return am.GetSetupKeyFunc(accountID, userID, keyID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
}
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) {
func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID)
return am.ListSetupKeysFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
@@ -435,9 +414,9 @@ func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*
}
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, enabled)
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled)
}
return nil, nil
}
@@ -489,3 +468,11 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer, error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetPeersFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeersFunc is not implemented")
}

View File

@@ -1,8 +1,10 @@
package server
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strconv"
@@ -20,6 +22,10 @@ const (
UpdateNameServerGroupGroups
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
UpdateNameServerGroupEnabled
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
UpdateNameServerGroupPrimary
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
UpdateNameServerGroupDomains
)
// NameServerGroupUpdateOperationType operation type
@@ -37,6 +43,10 @@ func (t NameServerGroupUpdateOperationType) String() string {
return "UpdateNameServerGroupGroups"
case UpdateNameServerGroupEnabled:
return "UpdateNameServerGroupEnabled"
case UpdateNameServerGroupPrimary:
return "UpdateNameServerGroupPrimary"
case UpdateNameServerGroupDomains:
return "UpdateNameServerGroupDomains"
default:
return "InvalidOperation"
}
@@ -50,8 +60,9 @@ type NameServerGroupUpdateOperation struct {
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -67,9 +78,10 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -83,6 +95,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
NameServers: nameServerList,
Groups: groups,
Enabled: enabled,
Primary: primary,
Domains: domains,
}
err = validateNameServerGroup(false, newNSGroup, account)
@@ -102,13 +116,20 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after create nameserver %s", name)
}
return newNSGroup.Copy(), nil
}
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
@@ -132,13 +153,20 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
return err
}
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
}
return nil
}
// UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -205,6 +233,18 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
}
newNSGroup.Primary = primary
case UpdateNameServerGroupDomains:
err = validateDomainInput(false, operation.Values)
if err != nil {
return nil, err
}
newNSGroup.Domains = operation.Values
}
}
@@ -216,13 +256,20 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", newNSGroup.Name)
}
return newNSGroup.Copy(), nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -237,13 +284,20 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
return err
}
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after deleting nameserver %s", nsGroupID)
}
return nil
}
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -268,7 +322,12 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
}
}
err := validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
@@ -286,6 +345,24 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
return nil
}
func validateDomainInput(primary bool, domains []string) error {
if !primary && len(domains) == 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain")
}
if primary && len(domains) != 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
" you should set either primary or domain")
}
for _, domain := range domains {
_, valid := dns.IsDomainName(domain)
if !valid {
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
}
}
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)

View File

@@ -14,6 +14,8 @@ const (
existingNSGroupID = "existingNSGroup"
nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI="
validDomain = "example.com"
invalidDomain = "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com"
)
func TestCreateNameServerGroup(t *testing.T) {
@@ -23,6 +25,8 @@ func TestCreateNameServerGroup(t *testing.T) {
enabled bool
groups []string
nameServers []nbdns.NameServer
primary bool
domains []string
}
testCases := []struct {
@@ -33,11 +37,12 @@ func TestCreateNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Create A NS Group",
name: "Create A NS Group With Primary Status",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
primary: true,
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -57,6 +62,52 @@ func TestCreateNameServerGroup(t *testing.T) {
expectedNSGroup: &nbdns.NameServerGroup{
Name: "super",
Description: "super",
Primary: true,
Groups: []string{group1ID},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Enabled: true,
},
},
{
name: "Create A NS Group With Domains",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
primary: false,
domains: []string{validDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
Name: "super",
Description: "super",
Primary: false,
Domains: []string{"example.com"},
Groups: []string{group1ID},
NameServers: []nbdns.NameServer{
{
@@ -78,6 +129,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: existingNSGroupName,
description: "super",
primary: true,
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
@@ -101,6 +153,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: "",
description: "super",
primary: true,
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
@@ -124,6 +177,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: "1234567890123456789012345678901234567890extra",
description: "super",
primary: true,
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
@@ -147,6 +201,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: "super",
description: "super",
primary: true,
groups: []string{group1ID},
nameServers: []nbdns.NameServer{},
enabled: true,
@@ -159,6 +214,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: "super",
description: "super",
primary: true,
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
@@ -187,6 +243,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: "super",
description: "super",
primary: true,
groups: []string{},
nameServers: []nbdns.NameServer{
{
@@ -210,6 +267,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: "super",
description: "super",
primary: true,
groups: []string{"missingGroup"},
nameServers: []nbdns.NameServer{
{
@@ -233,6 +291,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{
name: "super",
description: "super",
primary: true,
groups: []string{""},
nameServers: []nbdns.NameServer{
{
@@ -251,6 +310,53 @@ func TestCreateNameServerGroup(t *testing.T) {
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If No Domain Or Primary",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Domain List Is Invalid",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
domains: []string{invalidDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
@@ -270,6 +376,8 @@ func TestCreateNameServerGroup(t *testing.T) {
testCase.inputArgs.description,
testCase.inputArgs.nameServers,
testCase.inputArgs.groups,
testCase.inputArgs.primary,
testCase.inputArgs.domains,
testCase.inputArgs.enabled,
)
@@ -295,6 +403,7 @@ func TestSaveNameServerGroup(t *testing.T) {
ID: "testingNSGroup",
Name: "super",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -313,6 +422,10 @@ func TestSaveNameServerGroup(t *testing.T) {
validGroups := []string{group2ID}
invalidGroups := []string{"nonExisting"}
disabledPrimary := false
validDomains := []string{validDomain}
invalidDomains := []string{invalidDomain}
validNameServerList := []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -348,6 +461,8 @@ func TestSaveNameServerGroup(t *testing.T) {
existingNSGroup *nbdns.NameServerGroup
newID *string
newName *string
newPrimary *bool
newDomains []string
newNSList []nbdns.NameServer
newGroups []string
skipCopying bool
@@ -356,16 +471,20 @@ func TestSaveNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Should Update Name Server Group",
name: "Should Config Name Server Group",
existingNSGroup: existingNSGroup,
newName: &validName,
newGroups: validGroups,
newPrimary: &disabledPrimary,
newDomains: validDomains,
newNSList: validNameServerList,
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: "testingNSGroup",
Name: validName,
Primary: false,
Domains: validDomains,
Description: "super",
NameServers: validNameServerList,
Groups: validGroups,
@@ -373,68 +492,91 @@ func TestSaveNameServerGroup(t *testing.T) {
},
},
{
name: "Should Not Update If Name Is Small",
name: "Should Not Config If Name Is Small",
existingNSGroup: existingNSGroup,
newName: &invalidNameSmall,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Name Is Large",
name: "Should Not Config If Name Is Large",
existingNSGroup: existingNSGroup,
newName: &invalidNameLarge,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Name Exists",
name: "Should Not Config If Name Exists",
existingNSGroup: existingNSGroup,
newName: &invalidNameExisting,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If ID Don't Exist",
name: "Should Not Config If ID Don't Exist",
existingNSGroup: existingNSGroup,
newID: &invalidID,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Nameserver List Is Small",
name: "Should Not Config If Nameserver List Is Small",
existingNSGroup: existingNSGroup,
newNSList: []nbdns.NameServer{},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Nameserver List Is Large",
name: "Should Not Config If Nameserver List Is Large",
existingNSGroup: existingNSGroup,
newNSList: invalidNameServerListLarge,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Groups List Is Empty",
name: "Should Not Config If Groups List Is Empty",
existingNSGroup: existingNSGroup,
newGroups: []string{},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Groups List Has Empty ID",
name: "Should Not Config If Groups List Has Empty ID",
existingNSGroup: existingNSGroup,
newGroups: []string{""},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Groups List Has Non Existing Group ID",
name: "Should Not Config If Groups List Has Non Existing Group ID",
existingNSGroup: existingNSGroup,
newGroups: invalidGroups,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Config If Domains List Is Empty",
existingNSGroup: existingNSGroup,
newPrimary: &disabledPrimary,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Config If Primary And Domains",
existingNSGroup: existingNSGroup,
newPrimary: &existingNSGroup.Primary,
newDomains: validDomains,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Config If Domains List Is Invalid",
existingNSGroup: existingNSGroup,
newPrimary: &disabledPrimary,
newDomains: invalidDomains,
errFunc: require.Error,
shouldCreate: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
@@ -475,6 +617,14 @@ func TestSaveNameServerGroup(t *testing.T) {
if testCase.newNSList != nil {
nsGroupToSave.NameServers = testCase.newNSList
}
if testCase.newPrimary != nil {
nsGroupToSave.Primary = *testCase.newPrimary
}
if testCase.newDomains != nil {
nsGroupToSave.Domains = testCase.newDomains
}
}
err = am.SaveNameServerGroup(account.Id, nsGroupToSave)
@@ -485,6 +635,11 @@ func TestSaveNameServerGroup(t *testing.T) {
return
}
account, err = am.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved)
@@ -503,6 +658,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID,
Name: "super",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -529,7 +685,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Should Update Single Property",
name: "Should Config Single Property",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -544,6 +700,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID,
Name: "superNew",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -561,7 +718,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
},
},
{
name: "Should Update Multiple Properties",
name: "Should Config Multiple Properties",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -585,6 +742,14 @@ func TestUpdateNameServerGroup(t *testing.T) {
Type: UpdateNameServerGroupEnabled,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{validDomain},
},
},
errFunc: require.NoError,
shouldCreate: true,
@@ -592,6 +757,8 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID,
Name: "superNew",
Description: "superDescription",
Primary: false,
Domains: []string{validDomain},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("127.0.0.1"),
@@ -609,20 +776,20 @@ func TestUpdateNameServerGroup(t *testing.T) {
},
},
{
name: "Should Not Update On Invalid ID",
name: "Should Not Config On Invalid ID",
existingNSGroup: existingNSGroup,
nsGroupID: "nonExistingNSGroup",
errFunc: require.Error,
},
{
name: "Should Not Update On Empty Operations",
name: "Should Not Config On Empty Operations",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{},
errFunc: require.Error,
},
{
name: "Should Not Update On Empty Values",
name: "Should Not Config On Empty Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -633,7 +800,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Empty String",
name: "Should Not Config On Empty String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -645,7 +812,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Name Large String",
name: "Should Not Config On Invalid Name Large String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -657,7 +824,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid On Existing Name",
name: "Should Not Config On Invalid On Existing Name",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -669,7 +836,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid On Multiple Name Values",
name: "Should Not Config On Invalid On Multiple Name Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -681,7 +848,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Boolean",
name: "Should Not Config On Invalid Boolean",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -693,7 +860,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Nameservers Wrong Schema",
name: "Should Not Config On Invalid Nameservers Wrong Schema",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -705,7 +872,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Nameservers Wrong IP",
name: "Should Not Config On Invalid Nameservers Wrong IP",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -717,7 +884,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Large Number Of Nameservers",
name: "Should Not Config On Large Number Of Nameservers",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -729,7 +896,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid GroupID",
name: "Should Not Config On Invalid GroupID",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
@@ -740,6 +907,30 @@ func TestUpdateNameServerGroup(t *testing.T) {
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Domains",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{invalidDomain},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Primary Status",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
@@ -865,12 +1056,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
if err != nil {
return nil, err
}
return BuildManager(store, NewPeersUpdateManager(), nil)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
}
func createNSStore(t *testing.T) (Store, error) {
dataDir := t.TempDir()
store, err := NewStore(dataDir)
store, err := NewFileStore(dataDir)
if err != nil {
return nil, err
}

View File

@@ -2,6 +2,7 @@ package server
import (
"github.com/c-robinson/iplib"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
"google.golang.org/grpc/codes"
@@ -23,9 +24,10 @@ const (
)
type NetworkMap struct {
Peers []*Peer
Network *Network
Routes []*route.Route
Peers []*Peer
Network *Network
Routes []*route.Route
DNSConfig nbdns.Config
}
type Network struct {

View File

@@ -1,6 +1,7 @@
package server
import (
nbdns "github.com/netbirdio/netbird/dns"
"net"
"strings"
"time"
@@ -43,7 +44,11 @@ type Peer struct {
// Meta is a Peer system meta data
Meta PeerSystemMeta
// Name is peer's name (machine name)
Name string
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus
// The user ID that registered the peer
UserID string
@@ -65,41 +70,84 @@ func (p *Peer) Copy() *Peer {
UserID: p.UserID,
SSHKey: p.SSHKey,
SSHEnabled: p.SSHEnabled,
DNSLabel: p.DNSLabel,
}
}
// GetPeer returns a peer from a Store
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
// Copy PeerStatus
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
Connected: p.Connected,
}
}
peer, err := am.Store.GetPeer(peerKey)
// GetPeer looks up peer by its public WireGuard key
func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) {
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return nil, err
}
return peer, nil
return account.FindPeerByPubKey(peerPubKey)
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
peers := make([]*Peer, 0, len(account.Peers))
for _, peer := range account.Peers {
if !user.IsAdmin() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin
continue
}
peers = append(peers, peer.Copy())
}
return peers, nil
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool) error {
peer, err := am.Store.GetPeer(peerKey)
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
account, err := am.Store.GetPeerAccount(peerKey)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peerCopy := peer.Copy()
peerCopy.Status.LastSeen = time.Now()
peerCopy.Status.Connected = connected
err = am.Store.SavePeer(account.Id, peerCopy)
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
}
newStatus := peer.Status.Copy()
newStatus.LastSeen = time.Now()
newStatus.Connected = connected
peer.Status = newStatus
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peerPubKey, *newStatus)
if err != nil {
return err
}
@@ -108,26 +156,38 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected boo
// UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
peer, err := am.Store.GetPeer(update.Key)
//TODO Peer.ID migration: we will need to replace search by ID here
peer, err := account.FindPeerByPubKey(update.Key)
if err != nil {
return nil, err
}
peerCopy := peer.Copy()
if peer.Name != "" {
peerCopy.Name = update.Name
peer.Name = update.Name
}
peerCopy.SSHEnabled = update.SSHEnabled
peer.SSHEnabled = update.SSHEnabled
err = am.Store.SavePeer(accountID, peerCopy)
existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil {
return nil, err
}
peer.DNSLabel = newLabel
account.UpdatePeer(peer)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
@@ -137,66 +197,33 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
return nil, err
}
return peerCopy, nil
return peer, nil
}
// RenamePeer changes peer's name
func (am *DefaultAccountManager) RenamePeer(
accountId string,
peerKey string,
newName string,
) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) (*Peer, error) {
peer, err := am.Store.GetPeer(peerKey)
if err != nil {
return nil, err
}
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
peerCopy := peer.Copy()
peerCopy.Name = newName
err = am.Store.SavePeer(accountId, peerCopy)
if err != nil {
return nil, err
}
return peerCopy, nil
}
// DeletePeer removes peer from the account by it's IP
func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
// delete peer from groups
for _, g := range account.Groups {
for i, pk := range g.Peers {
if pk == peerKey {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
break
}
}
}
peer, err := am.Store.DeletePeer(accountId, peerKey)
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return nil, err
}
account.Network.IncSerial()
account.DeletePeer(peerPubKey)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
err = am.peersUpdateManager.SendUpdate(peerKey,
err = am.peersUpdateManager.SendUpdate(peerPubKey,
&UpdateMessage{
Update: &proto.SyncResponse{
// fill those field for backward compatibility
@@ -214,20 +241,22 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
return nil, err
}
// TODO Peer.ID migration: we will need to replace search by Peer.ID here
if err := am.updateAccountPeers(account); err != nil {
return nil, err
}
am.peersUpdateManager.CloseChannel(peerKey)
am.peersUpdateManager.CloseChannel(peerPubKey)
return peer, nil
}
// GetPeerByIP returns peer by it's IP
func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
// GetPeerByIP returns peer by its IP
func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) {
account, err := am.Store.GetAccount(accountId)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
@@ -242,33 +271,42 @@ func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, error) {
account, err := am.Store.GetPeerAccount(peerKey)
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey)
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey)
}
aclPeers := am.getPeersByACL(account, peerKey)
routesUpdate := am.getPeersRoutes(append(aclPeers, account.Peers[peerKey]))
aclPeers := am.getPeersByACL(account, peerPubKey)
routesUpdate := account.GetPeersRoutes(append(aclPeers, account.Peers[peerPubKey]))
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(account, am.dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate := nbdns.Config{
ServiceEnable: true,
CustomZones: zones,
NameServerGroups: getPeerNSGroups(account, peerPubKey),
}
return &NetworkMap{
Peers: aclPeers,
Network: account.Network.Copy(),
Routes: routesUpdate,
Peers: aclPeers,
Network: account.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
}, err
}
// GetPeerNetwork returns the Network for a given peer
func (am *DefaultAccountManager) GetPeerNetwork(peerKey string) (*Network, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, error) {
account, err := am.Store.GetPeerAccount(peerKey)
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey)
return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey)
}
return account.Network.Copy(), err
@@ -281,13 +319,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerKey string) (*Network, error
// to it. We also add the User ID to the peer metadata to identify registrant.
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(
setupKey string,
userID string,
peer *Peer,
) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
upperKey := strings.ToUpper(setupKey)
@@ -326,7 +358,7 @@ func (am *DefaultAccountManager) AddPeer(
groupsToAdd = sk.AutoGroups
} else if len(userID) != 0 {
account, err = am.Store.GetUserAccount(userID)
account, err = am.Store.GetAccountByUser(userID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
}
@@ -342,11 +374,31 @@ func (am *DefaultAccountManager) AddPeer(
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
}
var takenIps []net.IP
for _, peer := range account.Peers {
takenIps = append(takenIps, peer.IP)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, err
}
var takenIps []net.IP
existingLabels := make(lookupMap)
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
if existingPeer.DNSLabel != "" {
existingLabels[existingPeer.DNSLabel] = struct{}{}
}
}
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil {
return nil, err
}
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
@@ -359,6 +411,7 @@ func (am *DefaultAccountManager) AddPeer(
IP: nextIp,
Meta: peer.Meta,
Name: peer.Name,
DNSLabel: newLabel,
UserID: userID,
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
@@ -395,34 +448,41 @@ func (am *DefaultAccountManager) AddPeer(
}
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey string) error {
if sshKey == "" {
log.Debugf("empty SSH key provided for peer %s, skipping update", peerKey)
log.Debugf("empty SSH key provided for peer %s, skipping update", peerPubKey)
return nil
}
peer, err := am.Store.GetPeer(peerKey)
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
}
if peer.SSHKey == sshKey {
log.Debugf("same SSH key provided for peer %s, skipping update", peerKey)
log.Debugf("same SSH key provided for peer %s, skipping update", peerPubKey)
return nil
}
account, err := am.Store.GetPeerAccount(peerKey)
if err != nil {
return err
}
peer.SSHKey = sshKey
account.UpdatePeer(peer)
peerCopy := peer.Copy()
peerCopy.SSHKey = sshKey
err = am.Store.SavePeer(account.Id, peerCopy)
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
@@ -432,29 +492,30 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string)
}
// UpdatePeerMeta updates peer's system metadata
func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) UpdatePeerMeta(peerPubKey string, meta PeerSystemMeta) error {
peer, err := am.Store.GetPeer(peerKey)
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
account, err := am.Store.GetPeerAccount(peerKey)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
}
peerCopy := peer.Copy()
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = peerCopy.Meta.UIVersion
meta.UIVersion = peer.Meta.UIVersion
}
peerCopy.Meta = meta
peer.Meta = meta
account.UpdatePeer(peer)
err = am.Store.SavePeer(account.Id, peerCopy)
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
@@ -462,17 +523,9 @@ func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemM
}
// getPeersByACL returns all peers that given peer has access to.
func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string) []*Peer {
func (am *DefaultAccountManager) getPeersByACL(account *Account, peerPubKey string) []*Peer {
var peers []*Peer
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
if err != nil {
srcRules = []*Rule{}
}
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
if err != nil {
dstRules = []*Rule{}
}
srcRules, dstRules := account.GetPeerRules(peerPubKey)
groups := map[string]*Group{}
for _, r := range srcRules {
@@ -515,7 +568,7 @@ func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string)
continue
}
// exclude original peer
if _, ok := peersSet[peer.Key]; peer.Key != peerKey && !ok {
if _, ok := peersSet[peer.Key]; peer.Key != peerPubKey && !ok {
peersSet[peer.Key] = struct{}{}
peers = append(peers, peer.Copy())
}
@@ -528,36 +581,18 @@ func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string)
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
// notify other peers of the change
peers, err := am.Store.GetAccountPeers(account.Id)
if err != nil {
return err
}
network := account.Network.Copy()
peers := account.GetPeers()
for _, peer := range peers {
aclPeers := am.getPeersByACL(account, peer.Key)
peersUpdate := toRemotePeerConfig(aclPeers)
routesUpdate := toProtocolRoutes(am.getPeersRoutes(append(aclPeers, peer)))
err = am.peersUpdateManager.SendUpdate(peer.Key,
&UpdateMessage{
Update: &proto.SyncResponse{
// fill deprecated fields for backward compatibility
RemotePeers: peersUpdate,
RemotePeersIsEmpty: len(peersUpdate) == 0,
// new field
NetworkMap: &proto.NetworkMap{
Serial: account.Network.CurrentSerial(),
RemotePeers: peersUpdate,
RemotePeersIsEmpty: len(peersUpdate) == 0,
PeerConfig: toPeerConfig(peer, network),
Routes: routesUpdate,
},
},
})
remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key)
if err != nil {
return err
return status.Errorf(codes.Internal, "unable to fetch network map for peer %s, error: %v", peer.Key, err)
}
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap)
err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update})
if err != nil {
return status.Errorf(codes.Internal, "unable to send update for peer %s, error: %v", peer.Key, err)
}
}

View File

@@ -134,7 +134,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
return
}
rules, err := manager.ListRules(account.Id)
rules, err := manager.ListRules(account.Id, userId)
if err != nil {
t.Errorf("expecting to get a list of rules, got failure %v", err)
return

View File

@@ -60,15 +60,24 @@ type RouteUpdateOperation struct {
}
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
wantedRoute, found := account.Routes[routeID]
if found {
return wantedRoute, nil
@@ -84,7 +93,12 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
return nil
}
routesWithPrefix, err := am.Store.GetRoutesByPrefix(accountID, prefix)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
routesWithPrefix := account.GetRoutesByPrefix(prefix)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
@@ -102,8 +116,8 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -166,8 +180,8 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.Route) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if routeToSave == nil {
return status.Errorf(codes.InvalidArgument, "route provided is nil")
@@ -209,8 +223,8 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
// UpdateRoute updates existing route with set of operations
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -306,8 +320,8 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -325,15 +339,24 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID string) ([]*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
routes := make([]*route.Route, 0, len(account.Routes))
for _, item := range account.Routes {
routes = append(routes, item)
@@ -354,30 +377,6 @@ func toProtocolRoute(route *route.Route) *proto.Route {
}
}
func (am *DefaultAccountManager) getPeersRoutes(peers []*Peer) []*route.Route {
routes := make([]*route.Route, 0)
for _, peer := range peers {
peerRoutes, err := am.Store.GetPeerRoutes(peer.Key)
if err != nil {
errorStatus, ok := status.FromError(err)
if !ok && errorStatus.Code() != codes.NotFound {
log.Error(err)
}
continue
}
activeRoutes := make([]*route.Route, 0)
for _, pr := range peerRoutes {
if pr.Enabled {
activeRoutes = append(activeRoutes, pr)
}
}
if len(activeRoutes) > 0 {
routes = append(routes, activeRoutes...)
}
}
return routes
}
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0)
for _, r := range routes {

View File

@@ -380,6 +380,11 @@ func TestSaveRoute(t *testing.T) {
return
}
account, err = am.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
}
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
require.True(t, saved)
@@ -740,7 +745,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
err = am.SaveGroup(account.Id, newGroup)
require.NoError(t, err)
rules, err := am.ListRules(account.Id)
rules, err := am.ListRules(account.Id, "testingUser")
require.NoError(t, err)
defaultRule := rules[0]
@@ -778,12 +783,12 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
if err != nil {
return nil, err
}
return BuildManager(store, NewPeersUpdateManager(), nil)
return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
}
func createRouterStore(t *testing.T) (Store, error) {
dataDir := t.TempDir()
store, err := NewStore(dataDir)
store, err := NewFileStore(dataDir)
if err != nil {
return nil, err
}
@@ -840,5 +845,5 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err
}
return account, nil
return am.Store.GetAccount(account.Id)
}

View File

@@ -89,15 +89,24 @@ func (r *Rule) Copy() *Rule {
}
// GetRule of ACL from the store
func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "only admins are allowed to view rules")
}
rule, ok := account.Rules[ruleID]
if ok {
return rule, nil
@@ -108,8 +117,8 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error
// SaveRule of ACL in the store
func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -129,8 +138,8 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
// UpdateRule updates a rule using a list of operations
func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
operations []RuleUpdateOperation) (*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -203,8 +212,8 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
// DeleteRule of ACL from the store
func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -222,15 +231,24 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
}
// ListRules of ACL from the store
func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules")
}
rules := make([]*Rule, 0, len(account.Rules))
for _, item := range account.Rules {
rules = append(rules, item)

View File

@@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"time"
"unicode/utf8"
)
const (
@@ -100,6 +101,15 @@ func (key *SetupKey) Copy() *SetupKey {
}
}
// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
func (key *SetupKey) HiddenCopy() *SetupKey {
k := key.Copy()
prefix := k.Key[0:5]
k.Key = prefix + strings.Repeat("*", utf8.RuneCountInString(key.Key)-len(prefix))
return k
}
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy()
@@ -163,8 +173,8 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
@@ -198,8 +208,8 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
@@ -238,32 +248,48 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
}
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
keys := make([]*SetupKey, 0, len(account.SetupKeys))
for _, key := range account.SetupKeys {
keys = append(keys, key.Copy())
var k *SetupKey
if !user.IsAdmin() {
k = key.HiddenCopy()
} else {
k = key.Copy()
}
keys = append(keys, k)
}
return keys, nil
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
var foundKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyID {
@@ -280,5 +306,9 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey
foundKey.UpdatedAt = foundKey.CreatedAt
}
if !user.IsAdmin() {
foundKey = foundKey.HiddenCopy()
}
return foundKey, nil
}

View File

@@ -1,26 +1,20 @@
package server
import (
"github.com/netbirdio/netbird/route"
"net/netip"
)
type Store interface {
GetPeer(peerKey string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error)
SavePeer(accountId string, peer *Peer) error
GetAllAccounts() []*Account
GetAccount(accountId string) (*Account, error)
GetUserAccount(userId string) (*Account, error)
GetAccountPeers(accountId string) ([]*Peer, error)
GetPeerAccount(peerKey string) (*Account, error)
GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error)
GetPeerDstRules(accountId, peerKey string) ([]*Rule, error)
GetAccountBySetupKey(setupKey string) (*Account, error)
GetAccount(accountID string) (*Account, error)
GetAccountByUser(userID string) (*Account, error)
GetAccountByPeerPubKey(peerKey string) (*Account, error)
GetAccountBySetupKey(setupKey string) (*Account, error) //todo use key hash later
GetAccountByPrivateDomain(domain string) (*Account, error)
SaveAccount(account *Account) error
GetPeerRoutes(peerKey string) ([]*route.Route, error)
GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error)
GetInstallationID() string
SaveInstallationID(id string) error
SaveInstallationID(ID string) error
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
AcquireAccountLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func()
SavePeerStatus(accountID, peerKey string, status PeerStatus) error
// Close should close the store persisting all unsaved data.
Close() error
}

View File

@@ -0,0 +1,181 @@
package telemetry
import (
"context"
"fmt"
"github.com/gorilla/mux"
prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/exporters/prometheus"
metric2 "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/metric"
"net"
"net/http"
"reflect"
)
const defaultEndpoint = "/metrics"
// MockAppMetrics mocks the AppMetrics interface
type MockAppMetrics struct {
GetMeterFunc func() metric2.Meter
CloseFunc func() error
ExposeFunc func(port int, endpoint string) error
IDPMetricsFunc func() *IDPMetrics
HTTPMiddlewareFunc func() *HTTPMiddleware
GRPCMetricsFunc func() *GRPCMetrics
}
// GetMeter mocks the GetMeter function of the AppMetrics interface
func (mock *MockAppMetrics) GetMeter() metric2.Meter {
if mock.GetMeterFunc != nil {
return mock.GetMeterFunc()
}
return nil
}
// Close mocks the Close function of the AppMetrics interface
func (mock *MockAppMetrics) Close() error {
if mock.CloseFunc != nil {
return mock.CloseFunc()
}
return fmt.Errorf("unimplemented")
}
// Expose mocks the Expose function of the AppMetrics interface
func (mock *MockAppMetrics) Expose(port int, endpoint string) error {
if mock.ExposeFunc != nil {
return mock.ExposeFunc(port, endpoint)
}
return fmt.Errorf("unimplemented")
}
// IDPMetrics mocks the IDPMetrics function of the IDPMetrics interface
func (mock *MockAppMetrics) IDPMetrics() *IDPMetrics {
if mock.IDPMetricsFunc != nil {
return mock.IDPMetricsFunc()
}
return nil
}
// HTTPMiddleware mocks the HTTPMiddleware function of the IDPMetrics interface
func (mock *MockAppMetrics) HTTPMiddleware() *HTTPMiddleware {
if mock.HTTPMiddlewareFunc != nil {
return mock.HTTPMiddlewareFunc()
}
return nil
}
// GRPCMetrics mocks the GRPCMetrics function of the IDPMetrics interface
func (mock *MockAppMetrics) GRPCMetrics() *GRPCMetrics {
if mock.GRPCMetricsFunc != nil {
return mock.GRPCMetricsFunc()
}
return nil
}
// AppMetrics is metrics interface
type AppMetrics interface {
GetMeter() metric2.Meter
Close() error
Expose(port int, endpoint string) error
IDPMetrics() *IDPMetrics
HTTPMiddleware() *HTTPMiddleware
GRPCMetrics() *GRPCMetrics
}
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
type defaultAppMetrics struct {
// Meter can be used by different application parts to create counters and measure things
Meter metric2.Meter
listener net.Listener
ctx context.Context
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
}
// IDPMetrics returns metrics for the idp package
func (appMetrics *defaultAppMetrics) IDPMetrics() *IDPMetrics {
return appMetrics.idpMetrics
}
// HTTPMiddleware returns metrics for the http api package
func (appMetrics *defaultAppMetrics) HTTPMiddleware() *HTTPMiddleware {
return appMetrics.httpMiddleware
}
// GRPCMetrics returns metrics for the gRPC api
func (appMetrics *defaultAppMetrics) GRPCMetrics() *GRPCMetrics {
return appMetrics.grpcMetrics
}
// Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil {
return nil
}
return appMetrics.listener.Close()
}
// Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used.
// Exposes metrics in the Prometheus format https://prometheus.io/
func (appMetrics *defaultAppMetrics) Expose(port int, endpoint string) error {
if endpoint == "" {
endpoint = defaultEndpoint
}
rootRouter := mux.NewRouter()
rootRouter.Handle(endpoint, promhttp.HandlerFor(
prometheus2.DefaultGatherer,
promhttp.HandlerOpts{EnableOpenMetrics: true}))
listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
if err != nil {
return err
}
appMetrics.listener = listener
go func() {
err := http.Serve(listener, rootRouter)
if err != nil {
return
}
}()
log.Infof("enabled application metrics and exposing on http://%s", listener.Addr().String())
return nil
}
// GetMeter returns metrics meter that can be used to add various counters
func (appMetrics *defaultAppMetrics) GetMeter() metric2.Meter {
return appMetrics.Meter
}
// NewDefaultAppMetrics and expose them via defaultEndpoint on a given HTTP port
func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
exporter, err := prometheus.New()
if err != nil {
return nil, err
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(defaultEndpoint).PkgPath()
meter := provider.Meter(pkg)
idpMetrics, err := NewIDPMetrics(ctx, meter)
if err != nil {
return nil, err
}
middleware, err := NewMetricsMiddleware(ctx, meter)
if err != nil {
return nil, err
}
grpcMetrics, err := NewGRPCMetrics(ctx, meter)
if err != nil {
return nil, err
}
return &defaultAppMetrics{Meter: meter, ctx: ctx, idpMetrics: idpMetrics, httpMiddleware: middleware,
grpcMetrics: grpcMetrics}, nil
}

View File

@@ -0,0 +1,76 @@
package telemetry
import (
"context"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/asyncint64"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// GRPCMetrics are gRPC server metrics
type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter syncint64.Counter
loginRequestsCounter syncint64.Counter
getKeyRequestsCounter syncint64.Counter
activeStreamsGauge asyncint64.Gauge
ctx context.Context
}
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, error) {
syncRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.sync.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.login.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getKeyRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.key.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
activeStreamsGauge, err := meter.AsyncInt64().Gauge("management.grpc.connected.streams", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
return &GRPCMetrics{
meter: meter,
syncRequestsCounter: syncRequestsCounter,
loginRequestsCounter: loginRequestsCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
ctx: ctx,
}, err
}
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountSyncRequest() {
grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() {
grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// CountLoginRequest counts the number of gRPC login requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountLoginRequest() {
grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error {
return grpcMetrics.meter.RegisterCallback(
[]instrument.Asynchronous{
grpcMetrics.activeStreamsGauge,
},
func(ctx context.Context) {
grpcMetrics.activeStreamsGauge.Observe(ctx, producer())
},
)
}

View File

@@ -0,0 +1,176 @@
package telemetry
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
"hash/fnv"
"net/http"
"strings"
)
const (
httpRequestCounterPrefix = "management.http.request.counter"
httpResponseCounterPrefix = "management.http.response.counter"
)
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
// written HTTP status code to be captured for metrics reporting or logging purposes.
type WrappedResponseWriter struct {
http.ResponseWriter
status int
wroteHeader bool
}
// WrapResponseWriter wraps original http.ResponseWriter
func WrapResponseWriter(w http.ResponseWriter) *WrappedResponseWriter {
return &WrappedResponseWriter{ResponseWriter: w}
}
// Status returns response status
func (rw *WrappedResponseWriter) Status() int {
return rw.status
}
// WriteHeader wraps http.ResponseWriter.WriteHeader method
func (rw *WrappedResponseWriter) WriteHeader(code int) {
if rw.wroteHeader {
return
}
rw.status = code
rw.ResponseWriter.WriteHeader(code)
rw.wroteHeader = true
}
// HTTPMiddleware handler used to collect metrics of every request/response coming to the API.
// Also adds request tracing (logging).
type HTTPMiddleware struct {
meter metric.Meter
ctx context.Context
// defaultEndpoint & method
httpRequestCounters map[string]syncint64.Counter
// defaultEndpoint & method & status code
httpResponseCounters map[string]syncint64.Counter
// all HTTP requests
totalHTTPRequestsCounter syncint64.Counter
// all HTTP responses
totalHTTPResponseCounter syncint64.Counter
// all HTTP responses by status code
totalHTTPResponseCodeCounters map[int]syncint64.Counter
}
// AddHTTPRequestResponseCounter adds a new meter for an HTTP defaultEndpoint and Method (GET, POST, etc)
// Creates one request counter and multiple response counters (one per http response status code).
func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method string) error {
meterKey := getRequestCounterKey(endpoint, method)
httpReqCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
if err != nil {
return err
}
m.httpRequestCounters[meterKey] = httpReqCounter
respCodes := []int{200, 204, 400, 401, 403, 404, 500, 502, 503}
for _, code := range respCodes {
meterKey = getResponseCounterKey(endpoint, method, code)
httpRespCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
if err != nil {
return err
}
m.httpResponseCounters[meterKey] = httpRespCounter
meterKey = fmt.Sprintf("%s_%d_total", httpResponseCounterPrefix, code)
totalHTTPResponseCodeCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
if err != nil {
return err
}
m.totalHTTPResponseCodeCounters[code] = totalHTTPResponseCodeCounter
}
return nil
}
// NewMetricsMiddleware creates a new HTTPMiddleware
func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddleware, error) {
totalHTTPRequestsCounter, err := meter.SyncInt64().Counter(
fmt.Sprintf("%s_total", httpRequestCounterPrefix),
instrument.WithUnit("1"))
if err != nil {
return nil, err
}
totalHTTPResponseCounter, err := meter.SyncInt64().Counter(
fmt.Sprintf("%s_total", httpResponseCounterPrefix),
instrument.WithUnit("1"))
if err != nil {
return nil, err
}
return &HTTPMiddleware{
ctx: ctx,
httpRequestCounters: map[string]syncint64.Counter{},
httpResponseCounters: map[string]syncint64.Counter{},
totalHTTPResponseCodeCounters: map[int]syncint64.Counter{},
meter: meter,
totalHTTPRequestsCounter: totalHTTPRequestsCounter,
totalHTTPResponseCounter: totalHTTPResponseCounter,
},
nil
}
func getRequestCounterKey(endpoint, method string) string {
return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix,
strings.ReplaceAll(endpoint, "/", "_"), method)
}
func getResponseCounterKey(endpoint, method string, status int) string {
return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix,
strings.ReplaceAll(endpoint, "/", "_"), method, status)
}
// Handler logs every request and response and adds the, to metrics.
func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
fn := func(rw http.ResponseWriter, r *http.Request) {
traceID := hash(fmt.Sprintf("%v", r))
log.Tracef("HTTP request %v: %v %v", traceID, r.Method, r.URL)
metricKey := getRequestCounterKey(r.URL.Path, r.Method)
if c, ok := m.httpRequestCounters[metricKey]; ok {
c.Add(m.ctx, 1)
}
m.totalHTTPRequestsCounter.Add(m.ctx, 1)
w := WrapResponseWriter(rw)
h.ServeHTTP(w, r)
if w.Status() > 399 {
log.Errorf("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status())
} else {
log.Tracef("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status())
}
metricKey = getResponseCounterKey(r.URL.Path, r.Method, w.Status())
if c, ok := m.httpResponseCounters[metricKey]; ok {
c.Add(m.ctx, 1)
}
m.totalHTTPResponseCounter.Add(m.ctx, 1)
if c, ok := m.totalHTTPResponseCodeCounters[w.Status()]; ok {
c.Add(m.ctx, 1)
}
}
return http.HandlerFunc(fn)
}
func hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}

View File

@@ -0,0 +1,119 @@
package telemetry
import (
"context"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// IDPMetrics is common IdP metrics
type IDPMetrics struct {
metaUpdateCounter syncint64.Counter
getUserByEmailCounter syncint64.Counter
getAllAccountsCounter syncint64.Counter
createUserCounter syncint64.Counter
getAccountCounter syncint64.Counter
getUserByIDCounter syncint64.Counter
authenticateRequestCounter syncint64.Counter
requestErrorCounter syncint64.Counter
requestStatusErrorCounter syncint64.Counter
ctx context.Context
}
// NewIDPMetrics creates new IDPMetrics struct and registers common
func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) {
metaUpdateCounter, err := meter.SyncInt64().Counter("management.idp.update.user.meta.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getUserByEmailCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.email.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getAllAccountsCounter, err := meter.SyncInt64().Counter("management.idp.get.accounts.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
createUserCounter, err := meter.SyncInt64().Counter("management.idp.create.user.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getUserByIDCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.id.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
authenticateRequestCounter, err := meter.SyncInt64().Counter("management.idp.authenticate.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
requestErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.error.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
requestStatusErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.status.error.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
return &IDPMetrics{
metaUpdateCounter: metaUpdateCounter,
getUserByEmailCounter: getUserByEmailCounter,
getAllAccountsCounter: getAllAccountsCounter,
createUserCounter: createUserCounter,
getAccountCounter: getAccountCounter,
getUserByIDCounter: getUserByIDCounter,
authenticateRequestCounter: authenticateRequestCounter,
requestErrorCounter: requestErrorCounter,
requestStatusErrorCounter: requestStatusErrorCounter,
ctx: ctx}, nil
}
// CountUpdateUserAppMetadata ...
func (idpMetrics *IDPMetrics) CountUpdateUserAppMetadata() {
idpMetrics.metaUpdateCounter.Add(idpMetrics.ctx, 1)
}
// CountGetUserByEmail ...
func (idpMetrics *IDPMetrics) CountGetUserByEmail() {
idpMetrics.getUserByEmailCounter.Add(idpMetrics.ctx, 1)
}
// CountCreateUser ...
func (idpMetrics *IDPMetrics) CountCreateUser() {
idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1)
}
// CountGetAllAccounts ...
func (idpMetrics *IDPMetrics) CountGetAllAccounts() {
idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1)
}
// CountGetAccount ...
func (idpMetrics *IDPMetrics) CountGetAccount() {
idpMetrics.getAccountCounter.Add(idpMetrics.ctx, 1)
}
// CountGetUserDataByID ...
func (idpMetrics *IDPMetrics) CountGetUserDataByID() {
idpMetrics.getUserByIDCounter.Add(idpMetrics.ctx, 1)
}
// CountAuthenticate ...
func (idpMetrics *IDPMetrics) CountAuthenticate() {
idpMetrics.authenticateRequestCounter.Add(idpMetrics.ctx, 1)
}
// CountRequestError counts number of error that happened when doing http request (httpClient.Do)
func (idpMetrics *IDPMetrics) CountRequestError() {
idpMetrics.requestErrorCounter.Add(idpMetrics.ctx, 1)
}
// CountRequestStatusError counts number of responses that came from IdP with non success status code
func (idpMetrics *IDPMetrics) CountRequestStatusError() {
idpMetrics.requestStatusErrorCounter.Add(idpMetrics.ctx, 1)
}

View File

@@ -46,6 +46,11 @@ type User struct {
AutoGroups []string
}
// IsAdmin returns true if user is an admin, false otherwise
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
// toUserInfo converts a User object to a UserInfo object.
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups
@@ -68,7 +73,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
}
userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite {
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited
}
@@ -114,8 +119,8 @@ func NewAdminUser(id string) *User {
// CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if am.idpManager == nil {
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
@@ -179,8 +184,8 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if update == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
@@ -229,16 +234,16 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) {
unlock := am.Store.AcquireGlobalLock()
defer unlock()
lowerDomain := strings.ToLower(domain)
account, err := am.Store.GetUserAccount(userId)
account, err := am.Store.GetAccountByUser(userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account, err = am.newAccount(userId, lowerDomain)
account, err = am.newAccount(userID, lowerDomain)
if err != nil {
return nil, err
}
@@ -252,7 +257,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
}
}
userObj := account.Users[userId]
userObj := account.Users[userID]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
@@ -265,14 +270,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
return account, nil
}
// GetAccountByUser returns an existing account for a given user id, NotFound if account couldn't be found
func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
return am.Store.GetUserAccount(userId)
}
// IsUserAdmin flag for current user authenticated by JWT token
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, err := am.GetAccountFromToken(claims)
@@ -288,9 +285,15 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
return user.Role == UserRoleAdmin, nil
}
// GetUsersFromAccount performs a batched request for users from IDP by account ID
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) {
account, err := am.GetAccountById(accountID)
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
@@ -311,8 +314,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, user := range account.Users {
info, err := user.toUserInfo(nil)
for _, accountUser := range account.Users {
if !user.IsAdmin() && user.Id != accountUser.Id {
// if user is not an admin then show only current user and do not show other users
continue
}
info, err := accountUser.toUserInfo(nil)
if err != nil {
return nil, err
}
@@ -322,6 +329,10 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
}
for _, queriedUser := range queriedUsers {
if !user.IsAdmin() && user.Id != queriedUser.ID {
// if user is not an admin then show only current user and do not show other users
continue
}
if localUser, contains := account.Users[queriedUser.ID]; contains {
info, err := localUser.toUserInfo(queriedUser)

View File

@@ -3,7 +3,6 @@ package util
import (
"encoding/json"
"io"
"io/ioutil"
"os"
"path/filepath"
)
@@ -24,7 +23,7 @@ func WriteJson(file string, obj interface{}) error {
return err
}
tempFile, err := ioutil.TempFile(configDir, ".*"+configFileName)
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
}
@@ -43,7 +42,7 @@ func WriteJson(file string, obj interface{}) error {
}
}()
err = ioutil.WriteFile(tempFileName, bs, 0600)
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
@@ -65,7 +64,7 @@ func ReadJson(file string, res interface{}) (interface{}, error) {
}
defer f.Close()
bs, err := ioutil.ReadAll(f)
bs, err := io.ReadAll(f)
if err != nil {
return nil, err
}

View File

@@ -40,7 +40,7 @@ func InitLog(logLevel string, logPath string) error {
return "", fileName
}
if level == log.DebugLevel {
if level > log.WarnLevel {
log.SetReportCaller(true)
}