Compare commits

...

30 Commits

Author SHA1 Message Date
Maycon Santos
157137e4ad Use a single way to generate network map (#550) 2022-11-08 11:38:40 +01:00
Maycon Santos
7d7e576775 Set report caller when info or higher (#555) 2022-11-08 10:56:13 +01:00
Misha Bragin
f37b43a542 Save Peer Status separately in the FileStore (#554)
Due to peer reconnects when restarting the Management service,
there are lots of SaveStore operations to update peer status.

Store.SavePeerStatus stores peer status separately and the
FileStore implementation stores it in memory.
2022-11-08 10:46:12 +01:00
Maycon Santos
7e262572a4 Move dns label generation to store (#552) 2022-11-08 10:31:34 +01:00
Misha Bragin
a768a0aa8a Always lock the store when getting an account (#551) 2022-11-07 19:09:22 +01:00
Misha Bragin
ed7ac81027 Introduce locking on the account level (#548) 2022-11-07 17:52:23 +01:00
Maycon Santos
1f845f466c Add account copy test (#549) 2022-11-07 17:37:28 +01:00
Maycon Santos
270f0e4ce8 Feature/dns protocol (#543)
Added DNS update protocol message

Added sync to clients

Update nameserver API with new fields

Added default NS groups

Added new dns-name flag for the management service append to peer DNS label
2022-11-07 15:38:21 +01:00
Misha Bragin
d0c6d88971 Simplified Store Interface (#545)
This PR simplifies Store and FileStore
by keeping just the Get and Save account methods.

The AccountManager operates mostly around
a single account, so it makes sense to fetch
the whole account object from the store.
2022-11-07 12:10:56 +01:00
Misha Bragin
4321b71984 Hide content based on user role (#541) 2022-11-05 10:24:50 +01:00
Maycon Santos
e8d82c1bd3 Feature/dns-server (#537)
Adding DNS server for client

Updated the API with new fields

Added custom zone object for peer's DNS resolution
2022-11-03 18:39:37 +01:00
Misha Bragin
6aa7a2c5e1 Hide setup key from non-admin users (#539) 2022-11-03 17:02:31 +01:00
Rui Lopes
2e0bf61e9a correctly set the windows application icon on windows (#535)
the icon format is not really supported, so this uses a png instead.

this closes https://github.com/netbirdio/netbird/issues/534.
2022-11-01 00:34:30 +01:00
Maycon Santos
126af9dffc Return gateway address if not nil (#533)
If the gateway address would be nil which is
the case on macOS, we return the preferredSrc

added tests for getExistingRIBRouteGateway function

update log message
2022-10-31 11:54:34 +01:00
Maycon Santos
4cdf2df660 Update sign pipeline version to 0.0.4 (#531)
This version has a fix for the
macOS UI client architecture
2022-10-31 11:03:42 +01:00
Maycon Santos
9a4c9aa286 Add active peers count per OS (#526)
* Add active peers count per OS

* increase iface tests timeout
2022-10-26 14:48:40 +02:00
Rui Lopes
5ed61700ff Set the application icon, settings window title and systray tooltip (#523) 2022-10-26 14:34:30 +02:00
Misha Bragin
84117a9fb7 Update WireGuard trademark note 2022-10-23 11:47:42 +02:00
Misha Bragin
92b612eba4 Update demo video link 2022-10-22 16:55:49 +02:00
Misha Bragin
aeeaa21eed Update README.md (#524) 2022-10-22 16:19:16 +02:00
Misha Bragin
d228cd0cb1 Remove release note 2022-10-22 15:10:09 +02:00
Misha Bragin
b41f36fccd Add gRPC metrics (#522) 2022-10-22 15:06:54 +02:00
Misha Bragin
d2cde4a040 Add IdP metrics (#521) 2022-10-22 13:29:39 +02:00
Misha Bragin
84879a356b Extract app metrics to a separate struct (#520) 2022-10-22 11:50:21 +02:00
Misha Bragin
ed2214f9a9 Add HTTP request/response totals to metrics (#519) 2022-10-22 10:07:13 +02:00
braginini
4388dcc20b Listen metrics on all interfaces 2022-10-21 16:50:06 +02:00
Misha Bragin
4f1f0df7d2 Add Open-telemetry support (#517)
This PR brings open-telemetry metrics to the
Management service.
The Management service exposes new HTTP endpoint
/metrics on 8081 port by default.
The port can be changed by specifying
--metrics-port PORT flag when starting the service.
2022-10-21 16:24:13 +02:00
Misha Bragin
08ddf04c5f Fix IdP tests (#516) 2022-10-19 18:36:10 +02:00
Misha Bragin
b5ee2174a8 Do not set wt_pending_invite when unnecessary (#515)
wt_pending_invite property is set for every user on IdP.
Avoid setting it when unnecessary.
2022-10-19 17:51:41 +02:00
Misha Bragin
7218a3d563 Management single account mode (#511) 2022-10-19 17:43:28 +02:00
87 changed files with 4690 additions and 1411 deletions

View File

@@ -25,7 +25,6 @@ jobs:
needs: pre needs: pre
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.3" SIGN_PIPE_VER: "v0.0.4"
GORELEASER_VER: "v1.6.3" GORELEASER_VER: "v1.6.3"
jobs: jobs:

View File

@@ -1,6 +1,6 @@
<p align="center"> <p align="center">
<strong>:hatching_chick: New release! NetBird Easy SSH</strong>. <strong>:hatching_chick: New Release! User Invites.</strong>
<a href="https://github.com/netbirdio/netbird/releases/tag/v0.8.0"> <a href="https://github.com/netbirdio/netbird/releases">
Learn more Learn more
</a> </a>
</p> </p>
@@ -62,9 +62,8 @@ NetBird creates an overlay peer-to-peer network connecting machines automaticall
- \[ ] Network Activity Monitoring. - \[ ] Network Activity Monitoring.
### Secure peer-to-peer VPN with SSO and MFA in minutes ### Secure peer-to-peer VPN with SSO and MFA in minutes
<p float="left" align="middle">
<img src="docs/media/netbird-sso-mfa-demo.gif" width="800"/> https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
</p>
**Note**: The `main` branch may be in an *unstable or even broken state* during development. **Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases). For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
@@ -104,5 +103,5 @@ See a complete [architecture overview](https://netbird.io/docs/overview/architec
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution). We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal ### Legal
[WireGuard](https://wireguard.com/) is a registered trademark of Jason A. Donenfeld. _WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -62,18 +62,18 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, err := mgmt.NewStore(config.Datadir) store, err := mgmt.NewFileStore(config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peersUpdateManager := mgmt.NewPeersUpdateManager() peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil) accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -0,0 +1,56 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync"
)
type localResolver struct {
registeredMap registrationMap
records sync.Map
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received question: %#v\n", r.Question[0])
response := d.lookupRecord(r)
if response == nil {
log.Debugf("got empty response for question: %#v\n", r.Question[0])
return
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.Answer = append(replyMessage.Answer, response)
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
record, found := d.records.Load(r.Question[0].Name)
if !found {
return nil
}
return record.(dns.RR)
}
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
fullRecord, err := dns.NewRR(record.String())
if err != nil {
return err
}
d.records.Store(fullRecord.Header().Name, fullRecord)
return nil
}
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}

View File

@@ -0,0 +1,86 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"strings"
"testing"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@@ -0,0 +1,35 @@
package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
)
// MockServer is the mock instance of a dns server
type MockServer struct {
StartFunc func()
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
}
// Start mock implementation of Start from Server interface
func (m *MockServer) Start() {
if m.StartFunc != nil {
m.StartFunc()
}
}
// Stop mock implementation of Stop from Server interface
func (m *MockServer) Stop() {
if m.StopFunc != nil {
m.StopFunc()
}
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil {
return m.UpdateDNSServerFunc(serial, update)
}
return fmt.Errorf("method UpdateDNSServer is not implemented")
}

View File

@@ -0,0 +1,25 @@
package dns
import (
"github.com/miekg/dns"
"net"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@@ -0,0 +1,277 @@
package dns
import (
"context"
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync"
"time"
)
const (
port = 5053
defaultIP = "0.0.0.0"
)
// Server is a dns server interface
type Server interface {
Start()
Stop()
UpdateDNSServer(serial uint64, update nbdns.Config) error
}
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registrationMap
localResolver *localResolver
updateSerial uint64
listenerIsRunning bool
}
type registrationMap map[string]struct{}
type muxUpdate struct {
domain string
handler dns.Handler
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context) *DefaultServer {
mux := dns.NewServeMux()
dnsServer := &dns.Server{
Addr: fmt.Sprintf("%s:%d", defaultIP, port),
Net: "udp",
Handler: mux,
UDPSize: 65535,
}
ctx, stop := context.WithCancel(ctx)
return &DefaultServer{
ctx: ctx,
stop: stop,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registrationMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
}
}
// Start runs the listener in a go routine
func (s *DefaultServer) Start() {
log.Debugf("starting dns on %s:%d", defaultIP, port)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server returned an error: %v", err)
}
}()
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.stop()
err := s.stopListener()
if err != nil {
log.Error(err)
}
}
func (s *DefaultServer) stopListener() error {
if !s.listenerIsRunning {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
}
return nil
}
// UpdateDNSServer processes an update received from the management service
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
select {
case <-s.ctx.Done():
log.Infof("not updating DNS server as context is closed")
return s.ctx.Err()
default:
if serial < s.updateSerial {
return fmt.Errorf("not applying dns update, error: "+
"network update is %d behind the last applied update", s.updateSerial-serial)
}
s.mux.Lock()
defer s.mux.Unlock()
// is the service should be disabled, we stop the listener
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.Start()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
s.updateSerial = serial
return nil
}
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate
localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: customZone.Domain,
handler: s.localResolver,
})
for _, record := range customZone.Records {
localRecords[record.Name] = record
}
}
return muxUpdates, localRecords, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 {
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
}
handler := &upstreamResolver{
parentCTX: s.ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: defaultUpstreamTimeout,
}
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
}
if len(handler.upstreamServers) == 0 {
log.Errorf("received a nameserver group with an invalid nameserver list")
continue
}
if nsGroup.Primary {
muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
handler: handler,
})
continue
}
if len(nsGroup.Domains) == 0 {
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
})
}
}
return muxUpdates, nil
}
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registrationMap)
for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler)
muxUpdateMap[update.domain] = struct{}{}
}
for key := range s.dnsMuxMap {
_, found := muxUpdateMap[key]
if !found {
s.deregisterMux(key)
}
}
s.dnsMuxMap = muxUpdateMap
}
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
s.localResolver.deleteRecord(key)
}
}
updatedMap := make(registrationMap)
for key, record := range update {
err := s.localResolver.registerRecord(record)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
}
updatedMap[key] = struct{}{}
}
s.localResolver.registeredMap = updatedMap
}
func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}

View File

@@ -0,0 +1,285 @@
package dns
import (
"context"
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
)
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
},
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap registrationMap
initLocalMap registrationMap
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registrationMap
expectedLocalMap registrationMap
}{
{
name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}},
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}},
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
shouldFail: true,
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registrationMap),
expectedLocalMap: make(registrationMap),
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx := context.Background()
dnsServer := NewDefaultServer(ctx)
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial
dnsServer.listenerIsRunning = true
err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
}
}
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
}
for key := range testCase.expectedLocalMap {
_, found := dnsServer.localResolver.registeredMap[key]
if !found {
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
}
}
})
}
}
func TestDNSServerStartStop(t *testing.T) {
ctx := context.Background()
dnsServer := NewDefaultServer(ctx)
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
// todo review why this test is not working only on github actions workflows
t.Skip("skipping test in Windows CI workflows.")
}
dnsServer.Start()
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("127.0.0.1:%d", port)
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
// retry test before exit, for slower systems
return d.DialContext(ctx, network, addr)
}
return conn, nil
},
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed to connect to the server, error: %v", err)
}
t.Log(ips)
if ips[0] != zoneRecords[0].RData {
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
}
dnsServer.Stop()
ctx, cancel := context.WithTimeout(ctx, time.Second*1)
defer cancel()
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
if err == nil {
t.Fatalf("we should encounter an error when querying a stopped server")
}
}

View File

@@ -0,0 +1,67 @@
package dns
import (
"context"
"errors"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"net"
"time"
)
const defaultUpstreamTimeout = 15 * time.Second
type upstreamResolver struct {
parentCTX context.Context
upstreamClient *dns.Client
upstreamServers []string
upstreamTimeout time.Duration
}
// ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received an upstream question: %#v", r.Question[0])
select {
case <-u.parentCTX.Done():
return
default:
}
for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
continue
}
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
return
}
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
if err != nil {
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
}
return
}
log.Errorf("all queries to the upstream nameservers failed with timeout")
}
// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
func isTimeout(err error) bool {
var neterr net.Error
if errors.As(err, &neterr) {
return neterr != nil && neterr.Timeout()
}
return false
}

View File

@@ -0,0 +1,110 @@
package dns
import (
"context"
"github.com/miekg/dns"
"strings"
"testing"
"time"
)
func TestUpstreamResolver_ServeDNS(t *testing.T) {
testCases := []struct {
name string
inputMSG *dns.Msg
responseShouldBeNil bool
InputServers []string
timeout time.Duration
cancelCTX bool
expectedAnswer string
}{
{
name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: defaultUpstreamTimeout,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Resolve If First Upstream Times Out",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
timeout: 2 * time.Second,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Not Resolve If Can't Connect To Both Servers",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
timeout: 200 * time.Millisecond,
responseShouldBeNil: true,
},
{
name: "Should Not Resolve If Parent Context Is Canceled",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true,
timeout: defaultUpstreamTimeout,
responseShouldBeNil: true,
},
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
//},
}
// should resolve if first upstream times out
// should not write when both fails
// should not resolve if parent context is canceled
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver := &upstreamResolver{
parentCTX: ctx,
upstreamClient: &dns.Client{},
upstreamServers: testCase.InputServers,
upstreamTimeout: testCase.timeout,
}
if testCase.cancelCTX {
cancel()
} else {
defer cancel()
}
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
foundAnswer := false
for _, answer := range responseMSG.Answer {
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
}
if !foundAnswer {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
}
})
}
}

View File

@@ -3,12 +3,15 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
@@ -103,6 +106,8 @@ type Engine struct {
statusRecorder *nbstatus.Status statusRecorder *nbstatus.Status
routeManager routemanager.Manager routeManager routemanager.Manager
dnsServer dns.Server
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -130,6 +135,7 @@ func NewEngine(
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer, sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
dnsServer: dns.NewDefaultServer(ctx),
} }
} }
@@ -190,6 +196,10 @@ func (e *Engine) Stop() error {
e.routeManager.Stop() e.routeManager.Stop()
} }
if e.dnsServer != nil {
e.dnsServer.Stop()
}
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
return nil return nil
@@ -638,6 +648,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update routes, err: %v", err) log.Errorf("failed to update routes, err: %v", err)
} }
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
if err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial e.networkSerial = serial
return nil return nil
} }
@@ -660,6 +679,48 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes return routes
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
}
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
Name: record.GetName(),
Type: int(record.GetType()),
Class: record.GetClass(),
TTL: int(record.GetTTL()),
RData: record.GetRData(),
}
dnsZone.Records = append(dnsZone.Records, dnsRecord)
}
dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone)
}
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
dnsNSGroup := &nbdns.NameServerGroup{
Primary: nsGroup.GetPrimary(),
Domains: nsGroup.GetDomains(),
}
for _, ns := range nsGroup.GetNameServers() {
dnsNS := nbdns.NameServer{
IP: netip.MustParseAddr(ns.GetIP()),
NSType: nbdns.NameServerType(ns.GetNSType()),
Port: int(ns.GetPort()),
}
dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS)
}
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
}
return dnsUpdate
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update // addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate { for _, p := range peersUpdate {

View File

@@ -3,9 +3,11 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -440,7 +442,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial uint64 expectedSerial uint64
}{ }{
{ {
name: "Routes Update Should Be Passed To Manager", name: "Routes Config Should Be Passed To Manager",
networkMap: &mgmtProto.NetworkMap{ networkMap: &mgmtProto.NetworkMap{
Serial: 1, Serial: 1,
PeerConfig: nil, PeerConfig: nil,
@@ -486,7 +488,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial: 1, expectedSerial: 1,
}, },
{ {
name: "Empty Routes Update Should Be Passed", name: "Empty Routes Config Should Be Passed",
networkMap: &mgmtProto.NetworkMap{ networkMap: &mgmtProto.NetworkMap{
Serial: 1, Serial: 1,
PeerConfig: nil, PeerConfig: nil,
@@ -566,6 +568,183 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
} }
} }
func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
testCases := []struct {
name string
inputErr error
networkMap *mgmtProto.NetworkMap
expectedZonesLen int
expectedZones []nbdns.CustomZone
expectedNSGroupsLen int
expectedNSGroups []*nbdns.NameServerGroup
expectedSerial uint64
}{
{
name: "DNS Config Should Be Passed To DNS Server",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: &mgmtProto.DNSConfig{
ServiceEnable: true,
CustomZones: []*mgmtProto.CustomZone{
{
Domain: "netbird.cloud.",
Records: []*mgmtProto.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
NameServerGroups: []*mgmtProto.NameServerGroup{
{
Primary: true,
NameServers: []*mgmtProto.NameServer{
{
IP: "8.8.8.8",
NSType: 1,
Port: 53,
},
},
},
},
},
},
expectedZonesLen: 1,
expectedZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: 1,
Port: 53,
},
},
},
},
expectedSerial: 1,
},
{
name: "Empty DNS Config Should Be OK",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
{
name: "Error Shouldn't Break Engine",
inputErr: fmt.Errorf("mocking error"),
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
// test setup
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, nbstatus.NewRecorder())
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}
engine.routeManager = mockRouteManager
input := struct {
inputSerial uint64
inputNSGroups []*nbdns.NameServerGroup
inputZones []nbdns.CustomZone
}{}
mockDNSServer := &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error {
input.inputSerial = serial
input.inputZones = update.CustomZones
input.inputNSGroups = update.NameServerGroups
return testCase.inputErr
},
}
engine.dnsServer = mockDNSServer
defer func() {
exitErr := engine.Stop()
if exitErr != nil {
return
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
assert.Equal(t, testCase.expectedZones, input.inputZones, "custom zones should match")
assert.Len(t, input.inputNSGroups, testCase.expectedNSGroupsLen, "ns groups len should match")
assert.Equal(t, testCase.expectedNSGroups, input.inputNSGroups, "ns groups should match")
})
}
}
func TestEngine_MultiplePeers(t *testing.T) { func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel) // log.SetLevel(log.DebugLevel)
@@ -756,17 +935,17 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
return nil, err return nil, err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStore(config.Datadir) store, err := server.NewFileStore(config.Datadir)
if err != nil { if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
} }
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil) accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -21,7 +21,7 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
} }
if prefixGateway != nil && !prefixGateway.Equal(gateway) { if prefixGateway != nil && !prefixGateway.Equal(gateway) {
log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway) log.Warnf("skipping adding a new route for network %s because it already exists and is pointing to the non default gateway: %s", prefix, prefixGateway)
return nil return nil
} }
return addToRouteTable(prefix, addr) return addToRouteTable(prefix, addr)
@@ -45,11 +45,14 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice()) _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
if err != nil { if err != nil {
log.Errorf("getting routes returned an error: %v", err) log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound return nil, errRouteNotFound
} }
if gateway == nil {
return preferredSrc, nil
}
return localGatewayAddress, nil return gateway, nil
} }

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net"
"net/netip" "net/netip"
"testing" "testing"
) )
@@ -66,3 +67,45 @@ func TestAddRemoveRoutes(t *testing.T) {
}) })
} }
} }
func TestGetExistingRIBRouteGateway(t *testing.T) {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if gateway == nil {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
continue
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
localIP, err := getExistingRIBRouteGateway(testingPrefix)
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if localIP == nil {
t.Fatal("should return a gateway for local network")
}
if localIP.String() == gateway.String() {
t.Fatal("local ip should not match with gateway IP")
}
if localIP.String() != testingIP {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
}
}

View File

@@ -8,7 +8,6 @@ import (
"context" "context"
"flag" "flag"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/system"
"os" "os"
"os/exec" "os/exec"
"path" "path"
@@ -18,6 +17,8 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/netbirdio/netbird/client/system"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
_ "embed" _ "embed"
@@ -61,6 +62,8 @@ func main() {
flag.Parse() flag.Parse()
a := app.New() a := app.New()
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
client := newServiceClient(daemonAddr, a, showSettings) client := newServiceClient(daemonAddr, a, showSettings)
if showSettings { if showSettings {
a.Run() a.Run()
@@ -113,7 +116,7 @@ type serviceClient struct {
iLogFile *widget.Entry iLogFile *widget.Entry
iPreSharedKey *widget.Entry iPreSharedKey *widget.Entry
// observable settings over correspondign iMngURL and iPreSharedKey values. // observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string managementURL string
preSharedKey string preSharedKey string
adminURL string adminURL string
@@ -121,7 +124,7 @@ type serviceClient struct {
// newServiceClient instance constructor // newServiceClient instance constructor
// //
// This constructor olso build UI elements for settings window. // This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient { func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient {
s := &serviceClient{ s := &serviceClient{
ctx: context.Background(), ctx: context.Background(),
@@ -149,7 +152,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
func (s *serviceClient) showUIElements() { func (s *serviceClient) showUIElements() {
// add settings window UI elements. // add settings window UI elements.
s.wSettings = s.app.NewWindow("Settings") s.wSettings = s.app.NewWindow("NetBird Settings")
s.iMngURL = widget.NewEntry() s.iMngURL = widget.NewEntry()
s.iAdminURL = widget.NewEntry() s.iAdminURL = widget.NewEntry()
s.iConfigFile = widget.NewEntry() s.iConfigFile = widget.NewEntry()
@@ -326,11 +329,13 @@ func (s *serviceClient) updateStatus() error {
if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() { if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() {
systray.SetIcon(s.icConnected) systray.SetIcon(s.icConnected)
systray.SetTooltip("NetBird (Connected)")
s.mStatus.SetTitle("Connected") s.mStatus.SetTitle("Connected")
s.mUp.Disable() s.mUp.Disable()
s.mDown.Enable() s.mDown.Enable()
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
systray.SetIcon(s.icDisconnected) systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird (Disconnected)")
s.mStatus.SetTitle("Disconnected") s.mStatus.SetTitle("Disconnected")
s.mDown.Disable() s.mDown.Disable()
s.mUp.Enable() s.mUp.Enable()
@@ -355,6 +360,7 @@ func (s *serviceClient) updateStatus() error {
func (s *serviceClient) onTrayReady() { func (s *serviceClient) onTrayReady() {
systray.SetIcon(s.icDisconnected) systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird")
// setup systray menu items // setup systray menu items
s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected") s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected")

View File

@@ -2,5 +2,83 @@
// to parse and normalize dns records and configuration // to parse and normalize dns records and configuration
package dns package dns
// DefaultDNSPort well-known port number import (
const DefaultDNSPort = 53 "fmt"
"github.com/miekg/dns"
"golang.org/x/net/idna"
"regexp"
"strings"
)
const (
// DefaultDNSPort well-known port number
DefaultDNSPort = 53
// RootZone is a string representation of the root zone
RootZone = "."
// DefaultClass is the class supported by the system
DefaultClass = "IN"
)
const invalidHostLabel = "[^a-zA-Z0-9-]+"
// Config represents a dns configuration that is exchanged between management and peers
type Config struct {
// ServiceEnable indicates if the service should be enabled
ServiceEnable bool
// NameServerGroups contains a list of nameserver group
NameServerGroups []*NameServerGroup
// CustomZones contains a list of custom zone
CustomZones []CustomZone
}
// CustomZone represents a custom zone to be resolved by the dns server
type CustomZone struct {
// Domain is the zone's domain
Domain string
// Records custom zone records
Records []SimpleRecord
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
type SimpleRecord struct {
// Name domain name
Name string
// Type of record, 1 for A, 5 for CNAME, 28 for AAAA. see https://pkg.go.dev/github.com/miekg/dns@v1.1.41#pkg-constants
Type int
// Class dns class, currently use the DefaultClass for all records
Class string
// TTL time-to-live for the record
TTL int
// RData is the actual value resolved in a dns query
RData string
}
// String returns a string of the simple record formatted as:
// <Name> <TTL> <Class> <Type> <RDATA>
func (s SimpleRecord) String() string {
fqdn := dns.Fqdn(s.Name)
return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData)
}
// GetParsedDomainLabel returns a domain label with max 59 characters,
// parsed for old Hosts.txt requirements, and converted to ASCII and lowercase
func GetParsedDomainLabel(name string) (string, error) {
labels := dns.SplitDomainName(name)
if len(labels) == 0 {
return "", fmt.Errorf("got empty label list for name \"%s\"", name)
}
rawLabel := labels[0]
ascii, err := idna.Punycode.ToASCII(rawLabel)
if err != nil {
return "", fmt.Errorf("unable to convert host lavel to ASCII, error: %v", err)
}
invalidHostMatcher := regexp.MustCompile(invalidHostLabel)
validHost := strings.ToLower(invalidHostMatcher.ReplaceAllString(ascii, "-"))
if len(validHost) > 58 {
validHost = validHost[:59]
}
return validHost, nil
}

View File

@@ -9,8 +9,6 @@ import (
) )
const ( const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerType invalid nameserver type // InvalidNameServerType invalid nameserver type
InvalidNameServerType NameServerType = iota InvalidNameServerType NameServerType = iota
// UDPNameServerType udp nameserver type // UDPNameServerType udp nameserver type
@@ -18,6 +16,8 @@ const (
) )
const ( const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerTypeString invalid nameserver type as string // InvalidNameServerTypeString invalid nameserver type as string
InvalidNameServerTypeString = "invalid" InvalidNameServerTypeString = "invalid"
// UDPNameServerTypeString udp nameserver type as string // UDPNameServerTypeString udp nameserver type as string
@@ -59,6 +59,10 @@ type NameServerGroup struct {
NameServers []NameServer NameServers []NameServer
// Groups list of peer group IDs to distribute the nameservers information // Groups list of peer group IDs to distribute the nameservers information
Groups []string Groups []string
// Primary indicates that the nameserver group is the primary resolver for any dns query
Primary bool
// Domains indicate the dns query domains to use with this nameserver group
Domains []string
// Enabled group status // Enabled group status
Enabled bool Enabled bool
} }
@@ -128,6 +132,8 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
NameServers: g.NameServers, NameServers: g.NameServers,
Groups: g.Groups, Groups: g.Groups,
Enabled: g.Enabled, Enabled: g.Enabled,
Primary: g.Primary,
Domains: g.Domains,
} }
} }
@@ -136,8 +142,10 @@ func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
return other.ID == g.ID && return other.ID == g.ID &&
other.Name == g.Name && other.Name == g.Name &&
other.Description == g.Description && other.Description == g.Description &&
other.Primary == g.Primary &&
compareNameServerList(g.NameServers, other.NameServers) && compareNameServerList(g.NameServers, other.NameServers) &&
compareGroupsList(g.Groups, other.Groups) compareGroupsList(g.Groups, other.Groups) &&
compareGroupsList(g.Domains, other.Domains)
} }
func compareNameServerList(list, other []NameServer) bool { func compareNameServerList(list, other []NameServer) bool {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 887 KiB

23
go.mod
View File

@@ -18,12 +18,12 @@ require (
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1 golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.43.0 google.golang.org/grpc v1.43.0
google.golang.org/protobuf v1.28.0 google.golang.org/protobuf v1.28.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@@ -38,10 +38,15 @@ require (
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/libp2p/go-netroute v0.2.0 github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5 github.com/magiconair/properties v1.8.5
github.com/miekg/dns v1.1.41
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.13.0
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0
go.opentelemetry.io/otel/sdk/metric v0.33.0
golang.org/x/net v0.0.0-20220630215102-69896b714898 golang.org/x/net v0.0.0-20220630215102-69896b714898
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
) )
@@ -65,11 +70,13 @@ require (
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect github.com/go-stack/stack v1.8.0 // indirect
github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/google/go-cmp v0.5.7 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/google/gopacket v1.1.19 // indirect github.com/google/gopacket v1.1.19 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
@@ -89,20 +96,22 @@ require (
github.com/pion/turn/v2 v2.0.8 // indirect github.com/pion/turn/v2 v2.0.8 // indirect
github.com/pion/udp v0.1.1 // indirect github.com/pion/udp v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.12.2 // indirect
github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.33.0 // indirect github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.7.3 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
github.com/yuin/goldmark v1.4.1 // indirect github.com/yuin/goldmark v1.4.1 // indirect
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect
golang.org/x/tools v0.1.10 // indirect golang.org/x/tools v0.1.10 // indirect
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect

44
go.sum
View File

@@ -202,6 +202,11 @@ github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KE
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU=
github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0= github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -278,8 +283,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -450,6 +455,7 @@ github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMV
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g= github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
@@ -542,8 +548,8 @@ github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3O
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_golang v1.12.2 h1:51L9cDoUHVrXx4zWYlcLQIZ+d+VXHgqnYKkIuq4g/34= github.com/prometheus/client_golang v1.13.0 h1:b71QUfeo5M8gq2+evJdTPfZhYMAU0uKPkyPJ7TPsloU=
github.com/prometheus/client_golang v1.12.2/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -554,15 +560,16 @@ github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8b
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls=
github.com/prometheus/common v0.33.0 h1:rHgav/0a6+uYgGdNt3jwz8FNSesO/Hsang3O0T9A5SE= github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE=
github.com/prometheus/common v0.33.0/go.mod h1:gB3sOl7P0TvJabZpLY5uQMpUqRCPPCyRLCZYc7JZTNE= github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
@@ -648,6 +655,18 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0/go.mod h1:ZSmYfKdYWEdSDBB4njLBIwTf4AU2JNsH3n2quVQDebI=
go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E=
go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI=
go.opentelemetry.io/otel/sdk v1.11.1 h1:F7KmQgoHljhUuJyA+9BiU+EkJfyX5nVVF4wyzWZpKxs=
go.opentelemetry.io/otel/sdk v1.11.1/go.mod h1:/l3FE4SupHJ12TduVjUkZtlfFqDCQJlOlithYrdktys=
go.opentelemetry.io/otel/sdk/metric v0.33.0 h1:oTqyWfksgKoJmbrs2q7O7ahkJzt+Ipekihf8vhpa9qo=
go.opentelemetry.io/otel/sdk/metric v0.33.0/go.mod h1:xdypMeA21JBOvjjzDUtD0kzIcHO/SPez+a8HOzJPGp0=
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
@@ -809,8 +828,9 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f h1:Ax0t5p6N38Ga0dThY21weqDEyz2oklo4IvDkpigvkD8=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -915,8 +935,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU= golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM= golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
@@ -1162,8 +1182,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -89,7 +89,6 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
return addrs, nil return addrs, nil
} }
//
func Test_CreateInterface(t *testing.T) { func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32" wgIP := "10.99.99.1/32"
@@ -369,8 +368,8 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// todo: investigate why in some tests execution we need 30s
timeout := 10 * time.Second timeout := 30 * time.Second
timeoutChannel := time.After(timeout) timeoutChannel := time.After(timeout)
for { for {
select { select {

View File

@@ -214,8 +214,9 @@ func isBuiltinModule(name string) (bool, error) {
} }
// /proc/modules // /proc/modules
// name | memory size | reference count | references | state: <Live|Loading|Unloading> //
// macvlan 28672 1 macvtap, Live 0x0000000000000000 // name | memory size | reference count | references | state: <Live|Loading|Unloading>
// macvlan 28672 1 macvtap, Live 0x0000000000000000
func moduleStatus(name string) (status, error) { func moduleStatus(name string) (status, error) {
state := unknown state := unknown
f, err := os.Open("/proc/modules") f, err := os.Open("/proc/modules")

View File

@@ -10,6 +10,8 @@ NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem" NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem"
# Management Certficate key file path. # Management Certficate key file path.
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem" NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem"
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
# Turn credentials # Turn credentials

View File

@@ -48,7 +48,7 @@ services:
# # port and command for Let's Encrypt validation without dashboard container # # port and command for Let's Encrypt validation without dashboard container
# - 443:443 # - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS"] command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN"]
# Coturn # Coturn
coturn: coturn:
image: coturn/coturn image: coturn/coturn

View File

@@ -49,18 +49,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, err := mgmt.NewStore(config.Datadir) store, err := mgmt.NewFileStore(config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peersUpdateManager := mgmt.NewPeersUpdateManager() peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil) accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -0,0 +1,18 @@
package cmd
const (
defaultMgmtDataDir = "/var/lib/netbird/"
defaultMgmtConfigDir = "/etc/netbird"
defaultLogDir = "/var/log/netbird"
oldDefaultMgmtDataDir = "/var/lib/wiretrustee/"
oldDefaultMgmtConfigDir = "/etc/wiretrustee"
oldDefaultLogDir = "/var/log/wiretrustee"
defaultMgmtConfig = defaultMgmtConfigDir + "/management.json"
defaultLogFile = defaultLogDir + "/management.log"
oldDefaultMgmtConfig = oldDefaultMgmtConfigDir + "/management.json"
oldDefaultLogFile = oldDefaultLogDir + "/management.log"
defaultSingleAccModeDomain = "netbird.selfhosted"
)

View File

@@ -8,8 +8,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/miekg/dns"
httpapi "github.com/netbirdio/netbird/management/server/http" httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
@@ -41,11 +43,13 @@ import (
const ManagementLegacyPort = 33073 const ManagementLegacyPort = 33073
var ( var (
mgmtPort int mgmtPort int
mgmtLetsencryptDomain string mgmtMetricsPort int
certFile string mgmtLetsencryptDomain string
certKey string mgmtSingleAccModeDomain string
config *server.Config certFile string
certKey string
config *server.Config
kaep = keepalive.EnforcementPolicy{ kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second, MinTime: 15 * time.Second,
@@ -86,6 +90,11 @@ var (
} }
} }
_, valid := dns.IsDomainName(dnsDomain)
if !valid || len(dnsDomain) > 192 {
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Lenght: %d", valid, len(dnsDomain))
}
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
@@ -107,21 +116,33 @@ var (
} }
} }
store, err := server.NewStore(config.Datadir) store, err := server.NewFileStore(config.Datadir)
if err != nil { if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
} }
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
appMetrics, err := telemetry.NewDefaultAppMetrics(cmd.Context())
if err != nil {
return err
}
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
var idpManager idp.Manager var idpManager idp.Manager
if config.IdpManagerConfig != nil { if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(*config.IdpManagerConfig) idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err) return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
} }
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager) if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, dnsDomain)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }
@@ -151,14 +172,14 @@ var (
tlsEnabled = true tlsEnabled = true
} }
httpAPIHandler, err := httpapi.APIHandler(accountManager, httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation) config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }
@@ -170,8 +191,6 @@ var (
return err return err
} }
fmt.Println("metrics ", disableMetrics)
if !disableMetrics { if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -225,11 +244,13 @@ var (
SetupCloseHandler() SetupCloseHandler()
<-stopCh <-stopCh
_ = appMetrics.Close()
_ = listener.Close() _ = listener.Close()
if certManager != nil { if certManager != nil {
_ = certManager.Listener().Close() _ = certManager.Listener().Close()
} }
gRPCAPIHandler.Stop() gRPCAPIHandler.Stop()
_ = store.Close()
log.Infof("stopped Management Service") log.Infof("stopped Management Service")
return nil return nil
@@ -427,7 +448,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
return nil, err return nil, err
} }
// Create the credentials and return it // NewDefaultAppMetrics the credentials and return it
config := &tls.Config{ config := &tls.Config{
Certificates: []tls.Certificate{serverCert}, Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert, ClientAuth: tls.NoClientCert,

View File

@@ -13,21 +13,13 @@ const (
) )
var ( var (
defaultMgmtConfigDir string dnsDomain string
defaultMgmtDataDir string mgmtDataDir string
defaultMgmtConfig string mgmtConfig string
defaultLogDir string logLevel string
defaultLogFile string logFile string
oldDefaultMgmtConfigDir string disableMetrics bool
oldDefaultMgmtDataDir string disableSingleAccMode bool
oldDefaultMgmtConfig string
oldDefaultLogDir string
oldDefaultLogFile string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@@ -46,28 +38,17 @@ func Execute() error {
func init() { func init() {
stopCh = make(chan int) stopCh = make(chan int)
defaultMgmtDataDir = "/var/lib/netbird/"
defaultMgmtConfigDir = "/etc/netbird"
defaultLogDir = "/var/log/netbird"
oldDefaultMgmtDataDir = "/var/lib/wiretrustee/"
oldDefaultMgmtConfigDir = "/etc/wiretrustee"
oldDefaultLogDir = "/var/log/wiretrustee"
defaultMgmtConfig = defaultMgmtConfigDir + "/management.json"
defaultLogFile = defaultLogDir + "/management.log"
oldDefaultMgmtConfig = oldDefaultMgmtConfigDir + "/management.json"
oldDefaultLogFile = oldDefaultLogDir + "/management.log"
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -1,15 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.12.4 // protoc v3.21.9
// source: management.proto // source: management.proto
package proto package proto
import ( import (
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect" reflect "reflect"
sync "sync" sync "sync"
) )
@@ -611,7 +611,7 @@ type ServerKeyResponse struct {
// Server's Wireguard public key // Server's Wireguard public key
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
// Key expiration timestamp after which the key should be fetched again by the client // Key expiration timestamp after which the key should be fetched again by the client
ExpiresAt *timestamp.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"` ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"`
// Version of the Wiretrustee Management Service protocol // Version of the Wiretrustee Management Service protocol
Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"` Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"`
} }
@@ -655,7 +655,7 @@ func (x *ServerKeyResponse) GetKey() string {
return "" return ""
} }
func (x *ServerKeyResponse) GetExpiresAt() *timestamp.Timestamp { func (x *ServerKeyResponse) GetExpiresAt() *timestamppb.Timestamp {
if x != nil { if x != nil {
return x.ExpiresAt return x.ExpiresAt
} }
@@ -982,6 +982,8 @@ type NetworkMap struct {
RemotePeersIsEmpty bool `protobuf:"varint,4,opt,name=remotePeersIsEmpty,proto3" json:"remotePeersIsEmpty,omitempty"` RemotePeersIsEmpty bool `protobuf:"varint,4,opt,name=remotePeersIsEmpty,proto3" json:"remotePeersIsEmpty,omitempty"`
// List of routes to be applied // List of routes to be applied
Routes []*Route `protobuf:"bytes,5,rep,name=Routes,proto3" json:"Routes,omitempty"` Routes []*Route `protobuf:"bytes,5,rep,name=Routes,proto3" json:"Routes,omitempty"`
// DNS config to be applied
DNSConfig *DNSConfig `protobuf:"bytes,6,opt,name=DNSConfig,proto3" json:"DNSConfig,omitempty"`
} }
func (x *NetworkMap) Reset() { func (x *NetworkMap) Reset() {
@@ -1051,6 +1053,13 @@ func (x *NetworkMap) GetRoutes() []*Route {
return nil return nil
} }
func (x *NetworkMap) GetDNSConfig() *DNSConfig {
if x != nil {
return x.DNSConfig
}
return nil
}
// RemotePeerConfig represents a configuration of a remote peer. // RemotePeerConfig represents a configuration of a remote peer.
// The properties are used to configure Wireguard Peers sections // The properties are used to configure Wireguard Peers sections
type RemotePeerConfig struct { type RemotePeerConfig struct {
@@ -1467,6 +1476,334 @@ func (x *Route) GetNetID() string {
return "" return ""
} }
// DNSConfig represents a dns.Update
type DNSConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"`
NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"`
CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"`
}
func (x *DNSConfig) Reset() {
*x = DNSConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[20]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DNSConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DNSConfig) ProtoMessage() {}
func (x *DNSConfig) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[20]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead.
func (*DNSConfig) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{20}
}
func (x *DNSConfig) GetServiceEnable() bool {
if x != nil {
return x.ServiceEnable
}
return false
}
func (x *DNSConfig) GetNameServerGroups() []*NameServerGroup {
if x != nil {
return x.NameServerGroups
}
return nil
}
func (x *DNSConfig) GetCustomZones() []*CustomZone {
if x != nil {
return x.CustomZones
}
return nil
}
// CustomZone represents a dns.CustomZone
type CustomZone struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"`
Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"`
}
func (x *CustomZone) Reset() {
*x = CustomZone{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[21]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CustomZone) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CustomZone) ProtoMessage() {}
func (x *CustomZone) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[21]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CustomZone.ProtoReflect.Descriptor instead.
func (*CustomZone) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{21}
}
func (x *CustomZone) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *CustomZone) GetRecords() []*SimpleRecord {
if x != nil {
return x.Records
}
return nil
}
// SimpleRecord represents a dns.SimpleRecord
type SimpleRecord struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
Type int64 `protobuf:"varint,2,opt,name=Type,proto3" json:"Type,omitempty"`
Class string `protobuf:"bytes,3,opt,name=Class,proto3" json:"Class,omitempty"`
TTL int64 `protobuf:"varint,4,opt,name=TTL,proto3" json:"TTL,omitempty"`
RData string `protobuf:"bytes,5,opt,name=RData,proto3" json:"RData,omitempty"`
}
func (x *SimpleRecord) Reset() {
*x = SimpleRecord{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[22]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SimpleRecord) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SimpleRecord) ProtoMessage() {}
func (x *SimpleRecord) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[22]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead.
func (*SimpleRecord) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{22}
}
func (x *SimpleRecord) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *SimpleRecord) GetType() int64 {
if x != nil {
return x.Type
}
return 0
}
func (x *SimpleRecord) GetClass() string {
if x != nil {
return x.Class
}
return ""
}
func (x *SimpleRecord) GetTTL() int64 {
if x != nil {
return x.TTL
}
return 0
}
func (x *SimpleRecord) GetRData() string {
if x != nil {
return x.RData
}
return ""
}
// NameServerGroup represents a dns.NameServerGroup
type NameServerGroup struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
NameServers []*NameServer `protobuf:"bytes,1,rep,name=NameServers,proto3" json:"NameServers,omitempty"`
Primary bool `protobuf:"varint,2,opt,name=Primary,proto3" json:"Primary,omitempty"`
Domains []string `protobuf:"bytes,3,rep,name=Domains,proto3" json:"Domains,omitempty"`
}
func (x *NameServerGroup) Reset() {
*x = NameServerGroup{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[23]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *NameServerGroup) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*NameServerGroup) ProtoMessage() {}
func (x *NameServerGroup) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[23]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead.
func (*NameServerGroup) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{23}
}
func (x *NameServerGroup) GetNameServers() []*NameServer {
if x != nil {
return x.NameServers
}
return nil
}
func (x *NameServerGroup) GetPrimary() bool {
if x != nil {
return x.Primary
}
return false
}
func (x *NameServerGroup) GetDomains() []string {
if x != nil {
return x.Domains
}
return nil
}
// NameServer represents a dns.NameServer
type NameServer struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
NSType int64 `protobuf:"varint,2,opt,name=NSType,proto3" json:"NSType,omitempty"`
Port int64 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"`
}
func (x *NameServer) Reset() {
*x = NameServer{}
if protoimpl.UnsafeEnabled {
mi := &file_management_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *NameServer) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*NameServer) ProtoMessage() {}
func (x *NameServer) ProtoReflect() protoreflect.Message {
mi := &file_management_proto_msgTypes[24]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use NameServer.ProtoReflect.Descriptor instead.
func (*NameServer) Descriptor() ([]byte, []int) {
return file_management_proto_rawDescGZIP(), []int{24}
}
func (x *NameServer) GetIP() string {
if x != nil {
return x.IP
}
return ""
}
func (x *NameServer) GetNSType() int64 {
if x != nil {
return x.NSType
}
return 0
}
func (x *NameServer) GetPort() int64 {
if x != nil {
return x.Port
}
return 0
}
var File_management_proto protoreflect.FileDescriptor var File_management_proto protoreflect.FileDescriptor
var file_management_proto_rawDesc = []byte{ var file_management_proto_rawDesc = []byte{
@@ -1583,7 +1920,7 @@ var file_management_proto_rawDesc = []byte{
0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53,
0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x22, 0xf7, 0x01, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x66, 0x69, 0x67, 0x22, 0xac, 0x02, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01,
0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65,
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16,
@@ -1598,85 +1935,125 @@ var file_management_proto_rawDesc = []byte{
0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70,
0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03,
0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x83, 0x01, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a,
0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b,
0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e,
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66,
0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x69, 0x67, 0x22, 0x83, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65,
0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62,
0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62,
0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70,
0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64,
0x66, 0x69, 0x67, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73,
0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43,
0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62,
0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e,
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b,
0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62,
0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74,
0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65,
0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f,
0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20,
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20,
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48,
0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76,
0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c,
0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c,
0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f,
0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61,
0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04,
0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e,
0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70,
0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69,
0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18,
0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70,
0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x6f, 0x69, 0x6e, 0x74, 0x22, 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e,
0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18,
0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77,
0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e,
0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65,
0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16,
0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06,
0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71,
0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18,
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a,
0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65,
0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65,
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73,
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16,
0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74,
0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e,
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28,
0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63,
0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e,
0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65,
0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a,
0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a,
0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d,
0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52,
0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03,
0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54,
0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a,
0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44,
0x61, 0x74, 0x61, 0x22, 0x7f, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73,
0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28,
0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d,
0x61, 0x69, 0x6e, 0x73, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76,
0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02,
0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01,
0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f,
0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, 0xf7,
0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e,
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79,
0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53,
0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x74, 0x6f, 0x33, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61,
0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a,
0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69,
0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@@ -1692,7 +2069,7 @@ func file_management_proto_rawDescGZIP() []byte {
} }
var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 2) var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 20) var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 25)
var file_management_proto_goTypes = []interface{}{ var file_management_proto_goTypes = []interface{}{
(HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol (HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol
(DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider (DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider
@@ -1716,7 +2093,12 @@ var file_management_proto_goTypes = []interface{}{
(*DeviceAuthorizationFlow)(nil), // 19: management.DeviceAuthorizationFlow (*DeviceAuthorizationFlow)(nil), // 19: management.DeviceAuthorizationFlow
(*ProviderConfig)(nil), // 20: management.ProviderConfig (*ProviderConfig)(nil), // 20: management.ProviderConfig
(*Route)(nil), // 21: management.Route (*Route)(nil), // 21: management.Route
(*timestamp.Timestamp)(nil), // 22: google.protobuf.Timestamp (*DNSConfig)(nil), // 22: management.DNSConfig
(*CustomZone)(nil), // 23: management.CustomZone
(*SimpleRecord)(nil), // 24: management.SimpleRecord
(*NameServerGroup)(nil), // 25: management.NameServerGroup
(*NameServer)(nil), // 26: management.NameServer
(*timestamppb.Timestamp)(nil), // 27: google.protobuf.Timestamp
} }
var file_management_proto_depIdxs = []int32{ var file_management_proto_depIdxs = []int32{
11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig 11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
@@ -1727,7 +2109,7 @@ var file_management_proto_depIdxs = []int32{
6, // 5: management.LoginRequest.peerKeys:type_name -> management.PeerKeys 6, // 5: management.LoginRequest.peerKeys:type_name -> management.PeerKeys
11, // 6: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig 11, // 6: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
14, // 7: management.LoginResponse.peerConfig:type_name -> management.PeerConfig 14, // 7: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
22, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 27, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
12, // 9: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig 12, // 9: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig
13, // 10: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig 13, // 10: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
12, // 11: management.WiretrusteeConfig.signal:type_name -> management.HostConfig 12, // 11: management.WiretrusteeConfig.signal:type_name -> management.HostConfig
@@ -1737,24 +2119,29 @@ var file_management_proto_depIdxs = []int32{
14, // 15: management.NetworkMap.peerConfig:type_name -> management.PeerConfig 14, // 15: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
16, // 16: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig 16, // 16: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
21, // 17: management.NetworkMap.Routes:type_name -> management.Route 21, // 17: management.NetworkMap.Routes:type_name -> management.Route
17, // 18: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig 22, // 18: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
1, // 19: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider 17, // 19: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
20, // 20: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig 1, // 20: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
2, // 21: management.ManagementService.Login:input_type -> management.EncryptedMessage 20, // 21: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
2, // 22: management.ManagementService.Sync:input_type -> management.EncryptedMessage 25, // 22: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
10, // 23: management.ManagementService.GetServerKey:input_type -> management.Empty 23, // 23: management.DNSConfig.CustomZones:type_name -> management.CustomZone
10, // 24: management.ManagementService.isHealthy:input_type -> management.Empty 24, // 24: management.CustomZone.Records:type_name -> management.SimpleRecord
2, // 25: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage 26, // 25: management.NameServerGroup.NameServers:type_name -> management.NameServer
2, // 26: management.ManagementService.Login:output_type -> management.EncryptedMessage 2, // 26: management.ManagementService.Login:input_type -> management.EncryptedMessage
2, // 27: management.ManagementService.Sync:output_type -> management.EncryptedMessage 2, // 27: management.ManagementService.Sync:input_type -> management.EncryptedMessage
9, // 28: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse 10, // 28: management.ManagementService.GetServerKey:input_type -> management.Empty
10, // 29: management.ManagementService.isHealthy:output_type -> management.Empty 10, // 29: management.ManagementService.isHealthy:input_type -> management.Empty
2, // 30: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage 2, // 30: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
26, // [26:31] is the sub-list for method output_type 2, // 31: management.ManagementService.Login:output_type -> management.EncryptedMessage
21, // [21:26] is the sub-list for method input_type 2, // 32: management.ManagementService.Sync:output_type -> management.EncryptedMessage
21, // [21:21] is the sub-list for extension type_name 9, // 33: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
21, // [21:21] is the sub-list for extension extendee 10, // 34: management.ManagementService.isHealthy:output_type -> management.Empty
0, // [0:21] is the sub-list for field type_name 2, // 35: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
31, // [31:36] is the sub-list for method output_type
26, // [26:31] is the sub-list for method input_type
26, // [26:26] is the sub-list for extension type_name
26, // [26:26] is the sub-list for extension extendee
0, // [0:26] is the sub-list for field type_name
} }
func init() { file_management_proto_init() } func init() { file_management_proto_init() }
@@ -2003,6 +2390,66 @@ func file_management_proto_init() {
return nil return nil
} }
} }
file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DNSConfig); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CustomZone); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SimpleRecord); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NameServerGroup); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NameServer); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@@ -2010,7 +2457,7 @@ func file_management_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_management_proto_rawDesc, RawDescriptor: file_management_proto_rawDesc,
NumEnums: 2, NumEnums: 2,
NumMessages: 20, NumMessages: 25,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -178,6 +178,9 @@ message NetworkMap {
// List of routes to be applied // List of routes to be applied
repeated Route Routes = 5; repeated Route Routes = 5;
// DNS config to be applied
DNSConfig DNSConfig = 6;
} }
// RemotePeerConfig represents a configuration of a remote peer. // RemotePeerConfig represents a configuration of a remote peer.
@@ -247,3 +250,39 @@ message Route {
bool Masquerade = 6; bool Masquerade = 6;
string NetID = 7; string NetID = 7;
} }
// DNSConfig represents a dns.Update
message DNSConfig {
bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3;
}
// CustomZone represents a dns.CustomZone
message CustomZone {
string Domain = 1;
repeated SimpleRecord Records = 2;
}
// SimpleRecord represents a dns.SimpleRecord
message SimpleRecord {
string Name = 1;
int64 Type = 2;
string Class = 3;
int64 TTL = 4;
string RData = 5;
}
// NameServerGroup represents a dns.NameServerGroup
message NameServerGroup {
repeated NameServer NameServers = 1;
bool Primary = 2;
repeated string Domains = 3;
}
// NameServer represents a dns.NameServer
message NameServer {
string IP = 1;
int64 NSType = 2;
int64 Port = 3;
}

View File

@@ -15,6 +15,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"math/rand" "math/rand"
"net/netip"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
@@ -37,7 +38,6 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface { type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
CreateSetupKey( CreateSetupKey(
accountId string, accountId string,
keyName string, keyName string,
@@ -47,17 +47,16 @@ type AccountManager interface {
) (*SetupKey, error) ) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
CreateUser(accountID string, key *UserInfo) (*UserInfo, error) CreateUser(accountID string, key *UserInfo) (*UserInfo, error)
ListSetupKeys(accountID string) ([]*SetupKey, error) ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error) SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error) GetPeer(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error MarkPeerConnected(peerKey string, connected bool) error
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error) DeletePeer(accountId string, peerKey string) (*Peer, error)
GetPeerByIP(accountId string, peerIP string) (*Peer, error) GetPeerByIP(accountId string, peerIP string) (*Peer, error)
UpdatePeer(accountID string, peer *Peer) (*Peer, error) UpdatePeer(accountID string, peer *Peer) (*Peer, error)
@@ -66,7 +65,7 @@ type AccountManager interface {
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error) AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
UpdatePeerSSHKey(peerKey string, sshKey string) error UpdatePeerSSHKey(peerKey string, sshKey string) error
GetUsersFromAccount(accountId string) ([]*UserInfo, error) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error) GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountId string, group *Group) error SaveGroup(accountId string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
@@ -75,19 +74,19 @@ type AccountManager interface {
GroupAddPeer(accountId, groupID, peerKey string) error GroupAddPeer(accountId, groupID, peerKey string) error
GroupDeletePeer(accountId, groupID, peerKey string) error GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error) GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetRule(accountId, ruleID string) (*Rule, error) GetRule(accountID, ruleID, userID string) (*Rule, error)
SaveRule(accountID string, rule *Rule) error SaveRule(accountID string, rule *Rule) error
UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error) UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error)
DeleteRule(accountId, ruleID string) error DeleteRule(accountId, ruleID string) error
ListRules(accountId string) ([]*Rule, error) ListRules(accountID, userID string) ([]*Rule, error)
GetRoute(accountID, routeID string) (*route.Route, error) GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
SaveRoute(accountID string, route *route.Route) error SaveRoute(accountID string, route *route.Route) error
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error) UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error DeleteRoute(accountID, routeID string) error
ListRoutes(accountID string) ([]*route.Route, error) ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID string) error DeleteNameServerGroup(accountID, nsGroupID string) error
@@ -96,8 +95,6 @@ type AccountManager interface {
type DefaultAccountManager struct { type DefaultAccountManager struct {
Store Store Store Store
// mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
@@ -106,6 +103,15 @@ type DefaultAccountManager struct {
idpManager idp.Manager idpManager idp.Manager
cacheManager cache.CacheInterface[[]*idp.UserData] cacheManager cache.CacheInterface[[]*idp.UserData]
ctx context.Context ctx context.Context
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
} }
// Account represents a unique account of the system // Account represents a unique account of the system
@@ -135,6 +141,153 @@ type UserInfo struct {
Status string `json:"-"` Status string `json:"-"`
} }
// GetPeersRoutes returns all active routes of provided peers
func (a *Account) GetPeersRoutes(givenPeers []*Peer) []*route.Route {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
routes := make([]*route.Route, 0)
for _, peer := range givenPeers {
peerRoutes := a.GetPeerRoutes(peer.Key)
activeRoutes := make([]*route.Route, 0)
for _, pr := range peerRoutes {
if pr.Enabled {
activeRoutes = append(activeRoutes, pr)
}
}
if len(activeRoutes) > 0 {
routes = append(routes, activeRoutes...)
}
}
return routes
}
// GetPeerRoutes returns a list of routes of a given peer
func (a *Account) GetPeerRoutes(peerPubKey string) []*route.Route {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
var routes []*route.Route
for _, r := range a.Routes {
if r.Peer == peerPubKey {
routes = append(routes, r)
continue
}
}
return routes
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
var routes []*route.Route
for _, r := range a.Routes {
if r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes
}
// GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerPubKey string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{})
for s, group := range a.Groups {
for _, peer := range group.Peers {
if peerPubKey == peer {
peerGroups[s] = struct{}{}
break
}
}
}
// Second, find all rules that have discovered source and destination groups
srcRulesMap := make(map[string]*Rule)
dstRulesMap := make(map[string]*Rule)
for _, rule := range a.Rules {
for _, g := range rule.Source {
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
srcRules = append(srcRules, rule)
srcRulesMap[rule.ID] = rule
}
}
for _, g := range rule.Destination {
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
dstRules = append(dstRules, rule)
dstRulesMap[rule.ID] = rule
}
}
}
return srcRules, dstRules
}
// GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*Peer {
var peers []*Peer
for _, peer := range a.Peers {
peers = append(peers, peer)
}
return peers
}
// UpdatePeer saves new or replaces existing peer
func (a *Account) UpdatePeer(update *Peer) {
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
a.Peers[update.Key] = update
}
// DeletePeer deletes peer from the account cleaning up all the references
func (a *Account) DeletePeer(peerPubKey string) {
// TODO Peer.ID migration: we will need to replace search by Peer.ID here
// delete peer from groups
for _, g := range a.Groups {
for i, pk := range g.Peers {
if pk == peerPubKey {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
break
}
}
}
delete(a.Peers, peerPubKey)
a.Network.IncSerial()
}
// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found.
// It will return an object copy of the peer.
func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
for _, peer := range a.Peers {
if peer.Key == peerPubKey {
return peer.Copy(), nil
}
}
return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey)
}
// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, Errorf(UserNotFound, "user %s not found", userID)
}
return user, nil
}
// getPeerDNSLabels return the account's peers' dns labels
func (a *Account) getPeerDNSLabels() lookupMap {
existingLabels := make(lookupMap)
for _, peer := range a.Peers {
if peer.DNSLabel != "" {
existingLabels[peer.DNSLabel] = struct{}{}
}
}
return existingLabels
}
func (a *Account) Copy() *Account { func (a *Account) Copy() *Account {
peers := map[string]*Peer{} peers := map[string]*Peer{}
for id, peer := range a.Peers { for id, peer := range a.Peers {
@@ -172,16 +325,19 @@ func (a *Account) Copy() *Account {
} }
return &Account{ return &Account{
Id: a.Id, Id: a.Id,
CreatedBy: a.CreatedBy, CreatedBy: a.CreatedBy,
SetupKeys: setupKeys, Domain: a.Domain,
Network: a.Network.Copy(), DomainCategory: a.DomainCategory,
Peers: peers, IsDomainPrimaryAccount: a.IsDomainPrimaryAccount,
Users: users, SetupKeys: setupKeys,
Groups: groups, Network: a.Network.Copy(),
Rules: rules, Peers: peers,
Routes: routes, Users: users,
NameServerGroups: nsGroups, Groups: groups,
Rules: rules,
Routes: routes,
NameServerGroups: nsGroups,
} }
} }
@@ -195,36 +351,51 @@ func (a *Account) GetGroupAll() (*Group, error) {
} }
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager( func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string) (*DefaultAccountManager, error) {
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
mux: sync.Mutex{},
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
cacheMux: sync.Mutex{}, cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
}
allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
if am.singleAccountMode {
am.singleAccountModeDomain = singleAccountModeDomain
log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
} else {
log.Infof("single account mode disabled, accounts number %d", len(allAccounts))
} }
// if account has not default group // if account doesn't have a default group
// we create 'all' group and add all peers into it // we create 'all' group and add all peers into it
// also we create default rule with source as destination // also we create default rule with source as destination
for _, account := range store.GetAllAccounts() { for _, account := range allAccounts {
shouldSave := false
_, err := account.GetGroupAll() _, err := account.GetGroupAll()
if err != nil { if err != nil {
addAllGroup(account) addAllGroup(account)
if err := store.SaveAccount(account); err != nil { shouldSave = true
}
if shouldSave {
err = store.SaveAccount(account)
if err != nil {
return nil, err return nil, err
} }
} }
} }
gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
gocacheStore := cacheStore.NewGoCache(gocacheClient) goCacheStore := cacheStore.NewGoCache(goCacheClient)
am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](gocacheStore)) am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore))
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
go func() { go func() {
@@ -278,32 +449,17 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil return nil
} }
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) { // userID doesn't have an account associated with it, one account is created
am.mux.Lock() func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
defer am.mux.Unlock() if accountID != "" {
return am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(accountId) } else if userID != "" {
if err != nil { account, err := am.GetOrCreateAccountByUser(userID, domain)
return nil, status.Errorf(codes.NotFound, "account not found")
}
return account, nil
}
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*Account, error) {
if accountId != "" {
return am.GetAccountById(accountId)
} else if userId != "" {
account, err := am.GetOrCreateAccountByUser(userId, domain)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
} }
err = am.addAccountIDToIDPAppMeta(userId, account) err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -585,7 +741,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID) return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
} }
if user.AppMetadata.WTPendingInvite { if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.Infof("redeeming invite for user %s account %s", userID, account.Id) log.Infof("redeeming invite for user %s account %s", userID, account.Id)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false. // User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache. // Our job is to just reload cache.
@@ -604,6 +760,15 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
// GetAccountFromToken returns an account associated with this token // GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
claims.Domain = am.singleAccountModeDomain
claims.DomainCategory = PrivateCategory
log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
account, err := am.getAccountWithAuthorizationClaims(claims) account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -634,15 +799,13 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes // Existing user + Existing account + Existing Indexed Domain -> Nothing changes
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims( func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId) accountFromID, err := am.Store.GetAccount(claims.AccountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -654,8 +817,8 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
} }
} }
am.mux.Lock() unlock := am.Store.AcquireGlobalLock()
defer am.mux.Unlock() defer unlock()
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
@@ -664,7 +827,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
return nil, err return nil, err
} }
account, err := am.Store.GetUserAccount(claims.UserId) account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil { if err == nil {
err = am.handleExistingUserAccount(account, domainAccount, claims) err = am.handleExistingUserAccount(account, domainAccount, claims)
if err != nil { if err != nil {
@@ -685,12 +848,13 @@ func isDomainValid(domain string) bool {
} }
// AccountExists checks whether account exists (returns true) or not (returns false) // AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
var res bool var res bool
_, err := am.Store.GetAccount(accountId) _, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
res = false res = false

View File

@@ -1,7 +1,11 @@
package server package server
import ( import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"net" "net"
"reflect"
"sync" "sync"
"testing" "testing"
@@ -117,7 +121,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.GetAccountByUser(userId) account, err = manager.Store.GetAccountByUser(userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@@ -298,7 +302,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs { if testCase.inputUpdateAttrs {
@@ -341,7 +345,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.GetAccountByUser(userId) account, err = manager.Store.GetAccountByUser(userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@@ -397,7 +401,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
account, err := manager.GetAccountByUserOrAccountId(userId, "", "") account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -407,12 +411,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
accountId := account.Id accountId := account.Id
_, err = manager.GetAccountByUserOrAccountId("", accountId, "") _, err = manager.GetAccountByUserOrAccountID("", accountId, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
} }
_, err = manager.GetAccountByUserOrAccountId("", "", "") _, err = manager.GetAccountByUserOrAccountID("", "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user and account IDs are empty")
} }
@@ -466,7 +470,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
} }
// AddAccount has been already tested so we can assume it is correct and compare results // AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccountById(expectedId) getAccount, err := manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -536,7 +540,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -598,7 +602,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -676,7 +680,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer() peer2 := getPeer()
peer3 := getPeer() peer3 := getPeer()
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -844,7 +848,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -874,7 +878,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user account.Users[user.Id] = user
} }
userInfos, err := manager.GetUsersFromAccount(accountId) userInfos, err := manager.GetUsersFromAccount(accountId, "1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -957,17 +961,257 @@ func TestAccountManager_UpdatePeerMeta(t *testing.T) {
assert.Equal(t, newMeta, p.Meta) assert.Equal(t, newMeta, p.Meta)
} }
func TestAccount_GetPeerRules(t *testing.T) {
groups := map[string]*Group{
"group_1": {
ID: "group_1",
Name: "group_1",
Peers: []string{"peer-1", "peer-2"},
},
"group_2": {
ID: "group_2",
Name: "group_2",
Peers: []string{"peer-2", "peer-3"},
},
"group_3": {
ID: "group_3",
Name: "group_3",
Peers: []string{"peer-4"},
},
"group_4": {
ID: "group_4",
Name: "group_4",
Peers: []string{"peer-1"},
},
"group_5": {
ID: "group_5",
Name: "group_5",
Peers: []string{"peer-1"},
},
}
rules := map[string]*Rule{
"rule-1": {
ID: "rule-1",
Name: "rule-1",
Description: "rule-1",
Disabled: false,
Source: []string{"group_1", "group_5"},
Destination: []string{"group_2"},
Flow: 0,
},
"rule-2": {
ID: "rule-2",
Name: "rule-2",
Description: "rule-2",
Disabled: false,
Source: []string{"group_1"},
Destination: []string{"group_1"},
Flow: 0,
},
"rule-3": {
ID: "rule-3",
Name: "rule-3",
Description: "rule-3",
Disabled: false,
Source: []string{"group_3"},
Destination: []string{"group_3"},
Flow: 0,
},
}
account := &Account{
Groups: groups,
Rules: rules,
}
srcRules, dstRules := account.GetPeerRules("peer-1")
assert.Equal(t, 2, len(srcRules))
assert.Equal(t, 1, len(dstRules))
}
func TestFileStore_GetRoutesByPrefix(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
account := &Account{
Routes: map[string]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
"route-2": {
ID: "route-2",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
},
}
routes := account.GetRoutesByPrefix(prefix)
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-1")
assert.Contains(t, routeIDs, "route-2")
}
func TestAccount_GetPeersRoutes(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
account := &Account{
Peers: map[string]*Peer{
"peer-1": {Key: "peer-1"}, "peer-2": {Key: "peer-2"}, "peer-3": {Key: "peer-1"},
},
Routes: map[string]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
"route-2": {
ID: "route-2",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
},
},
}
routes := account.GetPeersRoutes([]*Peer{{Key: "peer-1"}, {Key: "peer-2"}, {Key: "non-existing-peer"}})
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-1")
assert.Contains(t, routeIDs, "route-2")
}
func TestAccount_Copy(t *testing.T) {
account := &Account{
Id: "account1",
CreatedBy: "tester",
Domain: "test.com",
DomainCategory: "public",
IsDomainPrimaryAccount: true,
SetupKeys: map[string]*SetupKey{
"setup1": {
Id: "setup1",
AutoGroups: []string{"group1"},
},
},
Network: &Network{
Id: "net1",
},
Peers: map[string]*Peer{
"peer1": {
Key: "key1",
},
},
Users: map[string]*User{
"user1": {
Id: "user1",
Role: UserRoleAdmin,
AutoGroups: []string{"group1"},
},
},
Groups: map[string]*Group{
"group1": {
ID: "group1",
},
},
Rules: map[string]*Rule{
"rule1": {
ID: "rule1",
},
},
Routes: map[string]*route.Route{
"route1": {
ID: "route1",
},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": {
ID: "nsGroup1",
},
},
}
err := hasNilField(account)
if err != nil {
t.Fatal(err)
}
accountCopy := account.Copy()
assert.Equal(t, account, accountCopy, "account copy returned a different value than expected")
}
// hasNilField validates pointers, maps and slices if they are nil
func hasNilField(x interface{}) error {
rv := reflect.ValueOf(x)
rv = rv.Elem()
for i := 0; i < rv.NumField(); i++ {
if f := rv.Field(i); f.IsValid() {
k := f.Kind()
switch k {
case reflect.Ptr:
if f.IsNil() {
return fmt.Errorf("field %s is nil", f.String())
}
case reflect.Map, reflect.Slice:
if f.Len() == 0 || f.IsNil() {
return fmt.Errorf("field %s is nil", f.String())
}
}
}
}
return nil
}
func createManager(t *testing.T) (*DefaultAccountManager, error) { func createManager(t *testing.T) (*DefaultAccountManager, error) {
store, err := createStore(t) store, err := createStore(t)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return BuildManager(store, NewPeersUpdateManager(), nil) return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
} }
func createStore(t *testing.T) (Store, error) { func createStore(t *testing.T) (Store, error) {
dataDir := t.TempDir() dataDir := t.TempDir()
store, err := NewStore(dataDir) store, err := NewFileStore(dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }

152
management/server/dns.go Normal file
View File

@@ -0,0 +1,152 @@
package server
import (
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus"
"strconv"
)
type lookupMap map[string]struct{}
const defaultTTL = 300
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
for _, zone := range update.CustomZones {
protoZone := &proto.CustomZone{Domain: zone.Domain}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
}
for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
}
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
}
return customZone
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
groupList := make(lookupMap)
for groupID, group := range account.Groups {
for _, id := range group.Peers {
if id == peerID {
groupList[groupID] = struct{}{}
break
}
}
}
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range account.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
return peerNSGroups
}
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil {
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil {
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skiping", peer.Meta.Hostname, err)
continue
}
}
peer.DNSLabel = label
peerLabels[label] = struct{}{}
}
}
func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) {
label, err := nbdns.GetParsedDomainLabel(name)
if err != nil {
return "", err
}
uniqueLabel := getUniqueHostLabel(label, peerLabels)
if uniqueLabel == "" {
return "", fmt.Errorf("couldn't find a unique valid lavel for %s, parsed label %s", name, label)
}
return uniqueLabel, nil
}
// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999
func getUniqueHostLabel(name string, peerLabels lookupMap) string {
_, found := peerLabels[name]
if !found {
return name
}
for i := 1; i < 1000; i++ {
nameWithSuffix := name + "-" + strconv.Itoa(i)
_, found = peerLabels[nameWithSuffix]
if !found {
return nameWithSuffix
}
}
return ""
}

View File

@@ -11,6 +11,12 @@ const (
AccountNotFound ErrorType = iota AccountNotFound ErrorType = iota
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled // PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed ErrorType = iota PreconditionFailed ErrorType = iota
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
UserNotFound ErrorType = iota
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied ErrorType = iota
) )
// ErrorType is a type of the Error // ErrorType is a type of the Error

View File

@@ -1,13 +1,12 @@
package server package server
import ( import (
"fmt" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/route"
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@@ -21,29 +20,29 @@ const storeFileName = "store.json"
// FileStore represents an account storage backed by a file persisted to disk // FileStore represents an account storage backed by a file persisted to disk
type FileStore struct { type FileStore struct {
Accounts map[string]*Account Accounts map[string]*Account
SetupKeyId2AccountId map[string]string `json:"-"` SetupKeyID2AccountID map[string]string `json:"-"`
PeerKeyId2AccountId map[string]string `json:"-"` PeerKeyID2AccountID map[string]string `json:"-"`
UserId2AccountId map[string]string `json:"-"` UserID2AccountID map[string]string `json:"-"`
PrivateDomain2AccountId map[string]string `json:"-"` PrivateDomain2AccountID map[string]string `json:"-"`
PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyID2RouteIDs map[string]map[string]struct{} `json:"-"`
AccountPrefix2RouteIDs map[string]map[string][]string `json:"-"`
InstallationID string InstallationID string
// mutex to synchronise Store read/write operations // mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"` mux sync.Mutex `json:"-"`
storeFile string `json:"-"` storeFile string `json:"-"`
// sync.Mutex indexed by accountID
accountLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
} }
type StoredAccount struct{} type StoredAccount struct{}
// NewStore restores a store from the file located in the datadir // NewFileStore restores a store from the file located in the datadir
func NewStore(dataDir string) (*FileStore, error) { func NewFileStore(dataDir string) (*FileStore, error) {
return restore(filepath.Join(dataDir, storeFileName)) return restore(filepath.Join(dataDir, storeFileName))
} }
// restore restores the state of the store from the file. // restore the state of the store from the file.
// Creates a new empty store file if doesn't exist // Creates a new empty store file if doesn't exist
func restore(file string) (*FileStore, error) { func restore(file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) { if _, err := os.Stat(file); os.IsNotExist(err) {
@@ -51,14 +50,11 @@ func restore(file string) (*FileStore, error) {
s := &FileStore{ s := &FileStore{
Accounts: make(map[string]*Account), Accounts: make(map[string]*Account),
mux: sync.Mutex{}, mux: sync.Mutex{},
SetupKeyId2AccountId: make(map[string]string), globalAccountLock: sync.Mutex{},
PeerKeyId2AccountId: make(map[string]string), SetupKeyID2AccountID: make(map[string]string),
UserId2AccountId: make(map[string]string), PeerKeyID2AccountID: make(map[string]string),
PrivateDomain2AccountId: make(map[string]string), UserID2AccountID: make(map[string]string),
PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}), PrivateDomain2AccountID: make(map[string]string),
PeerKeyID2RouteIDs: make(map[string]map[string]struct{}),
PeerKeyId2DstRulesId: make(map[string]map[string]struct{}),
AccountPrefix2RouteIDs: make(map[string]map[string][]string),
storeFile: file, storeFile: file,
} }
@@ -77,287 +73,113 @@ func restore(file string) (*FileStore, error) {
store := read.(*FileStore) store := read.(*FileStore)
store.storeFile = file store.storeFile = file
store.SetupKeyId2AccountId = make(map[string]string) store.SetupKeyID2AccountID = make(map[string]string)
store.PeerKeyId2AccountId = make(map[string]string) store.PeerKeyID2AccountID = make(map[string]string)
store.UserId2AccountId = make(map[string]string) store.UserID2AccountID = make(map[string]string)
store.PrivateDomain2AccountId = make(map[string]string) store.PrivateDomain2AccountID = make(map[string]string)
store.PeerKeyId2SrcRulesId = make(map[string]map[string]struct{})
store.PeerKeyId2DstRulesId = make(map[string]map[string]struct{})
store.PeerKeyID2RouteIDs = make(map[string]map[string]struct{})
store.AccountPrefix2RouteIDs = make(map[string]map[string][]string)
for accountId, account := range store.Accounts { for accountID, account := range store.Accounts {
for setupKeyId := range account.SetupKeys { for setupKeyId := range account.SetupKeys {
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId store.SetupKeyID2AccountID[strings.ToUpper(setupKeyId)] = accountID
}
for _, rule := range account.Rules {
for _, groupID := range rule.Source {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2SrcRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2SrcRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, groupID := range rule.Destination {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2DstRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2DstRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
store.PeerKeyId2AccountId[peer.Key] = accountId store.PeerKeyID2AccountID[peer.Key] = accountID
// reset all peers to status = Disconnected
if peer.Status != nil && peer.Status.Connected {
peer.Status.Connected = false
}
} }
for _, user := range account.Users { for _, user := range account.Users {
store.UserId2AccountId[user.Id] = accountId store.UserID2AccountID[user.Id] = accountID
} }
for _, user := range account.Users { for _, user := range account.Users {
store.UserId2AccountId[user.Id] = accountId store.UserID2AccountID[user.Id] = accountID
}
for _, route := range account.Routes {
if route.Peer == "" {
continue
}
if store.PeerKeyID2RouteIDs[route.Peer] == nil {
store.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
}
store.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
if store.AccountPrefix2RouteIDs[account.Id] == nil {
store.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
}
if _, ok := store.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
}
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
route.ID,
)
} }
if account.Domain != "" && account.DomainCategory == PrivateCategory && if account.Domain != "" && account.DomainCategory == PrivateCategory &&
account.IsDomainPrimaryAccount { account.IsDomainPrimaryAccount {
store.PrivateDomain2AccountId[account.Domain] = accountId store.PrivateDomain2AccountID[account.Domain] = accountID
} }
// for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels()
if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(account, existingLabels)
}
}
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
err = store.persist(store.storeFile)
if err != nil {
return nil, err
} }
return store, nil return store, nil
} }
// persist persists account data to a file // persist account data to a file
// It is recommended to call it with locking FileStore.mux // It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(file string) error { func (s *FileStore) persist(file string) error {
return util.WriteJson(file, s) return util.WriteJson(file, s)
} }
// SavePeer saves updated peer // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) SavePeer(accountId string, peer *Peer) error { func (s *FileStore) AcquireGlobalLock() (unlock func()) {
s.mux.Lock() log.Debugf("acquiring global lock")
defer s.mux.Unlock() start := time.Now()
s.globalAccountLock.Lock()
account, err := s.GetAccount(accountId) unlock = func() {
if err != nil { s.globalAccountLock.Unlock()
return err log.Debugf("released global lock in %v", time.Since(start))
} }
// if it is new peer, add it to default 'All' group return unlock
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
ind := -1
for i, pid := range allGroup.Peers {
if pid == peer.Key {
ind = i
break
}
}
if ind < 0 {
allGroup.Peers = append(allGroup.Peers, peer.Key)
}
account.Peers[peer.Key] = peer
return s.persist(s.storeFile)
} }
// DeletePeer deletes peer from the Store // AcquireAccountLock acquires account lock and returns a function that releases the lock
func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error) { func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
s.mux.Lock() log.Debugf("acquiring lock for account %s", accountID)
defer s.mux.Unlock() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
account, err := s.GetAccount(accountId) unlock = func() {
if err != nil { mtx.Unlock()
return nil, err log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
} }
peer := account.Peers[peerKey] return unlock
if peer == nil {
return nil, status.Errorf(codes.NotFound, "peer not found")
}
peerRoutes := s.PeerKeyID2RouteIDs[peerKey]
delete(account.Peers, peerKey)
delete(s.PeerKeyId2AccountId, peerKey)
delete(s.PeerKeyId2DstRulesId, peerKey)
delete(s.PeerKeyId2SrcRulesId, peerKey)
delete(s.PeerKeyID2RouteIDs, peerKey)
// cleanup groups
for _, g := range account.Groups {
var peers []string
for _, p := range g.Peers {
if p != peerKey {
peers = append(peers, p)
}
}
g.Peers = peers
}
for routeID := range peerRoutes {
account.Routes[routeID].Enabled = false
account.Routes[routeID].Peer = ""
}
err = s.persist(s.storeFile)
if err != nil {
return nil, err
}
return peer, nil
} }
// GetPeer returns a peer from a Store
func (s *FileStore) GetPeer(peerKey string) (*Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "peer not found")
}
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
if peer, ok := account.Peers[peerKey]; ok {
return peer, nil
}
return nil, status.Errorf(codes.NotFound, "peer not found")
}
// SaveAccount updates an existing account or adds a new one
func (s *FileStore) SaveAccount(account *Account) error { func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
accountCopy := account.Copy()
// todo will override, handle existing keys // todo will override, handle existing keys
s.Accounts[account.Id] = account s.Accounts[accountCopy.Id] = accountCopy
// todo check that account.Id and keyId are not exist already // todo check that account.Id and keyId are not exist already
// because if keyId exists for other accounts this can be bad // because if keyId exists for other accounts this can be bad
for keyId := range account.SetupKeys { for keyID := range accountCopy.SetupKeys {
s.SetupKeyId2AccountId[strings.ToUpper(keyId)] = account.Id s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id
} }
// enforce peer to account index and delete peer to route indexes for rebuild // enforce peer to account index and delete peer to route indexes for rebuild
for _, peer := range account.Peers { for _, peer := range accountCopy.Peers {
s.PeerKeyId2AccountId[peer.Key] = account.Id s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id
delete(s.PeerKeyID2RouteIDs, peer.Key)
} }
delete(s.AccountPrefix2RouteIDs, account.Id) for _, user := range accountCopy.Users {
s.UserID2AccountID[user.Id] = accountCopy.Id
// remove all peers related to account from rules indexes
cleanIDs := make([]string, 0)
for key := range s.PeerKeyId2SrcRulesId {
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
cleanIDs = append(cleanIDs, key)
}
}
for _, key := range cleanIDs {
delete(s.PeerKeyId2SrcRulesId, key)
}
cleanIDs = cleanIDs[:0]
for key := range s.PeerKeyId2DstRulesId {
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
cleanIDs = append(cleanIDs, key)
}
}
for _, key := range cleanIDs {
delete(s.PeerKeyId2DstRulesId, key)
} }
// rebuild rule indexes if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
for _, rule := range account.Rules { s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
for _, gid := range rule.Source {
g, ok := account.Groups[gid]
if !ok {
break
}
for _, pid := range g.Peers {
rules := s.PeerKeyId2SrcRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2SrcRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
for _, gid := range rule.Destination {
g, ok := account.Groups[gid]
if !ok {
break
}
for _, pid := range g.Peers {
rules := s.PeerKeyId2DstRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2DstRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, route := range account.Routes {
if route.Peer == "" {
continue
}
if s.PeerKeyID2RouteIDs[route.Peer] == nil {
s.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
}
s.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
if s.AccountPrefix2RouteIDs[account.Id] == nil {
s.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
}
if _, ok := s.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
}
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
route.ID,
)
}
for _, user := range account.Users {
s.UserId2AccountId[user.Id] = account.Id
}
if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
s.PrivateDomain2AccountId[account.Domain] = account.Id
} }
return s.persist(s.storeFile) return s.persist(s.storeFile)
@@ -365,53 +187,41 @@ func (s *FileStore) SaveAccount(account *Account) error {
// GetAccountByPrivateDomain returns account by private domain // GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)] s.mux.Lock()
if !accountIdFound { defer s.mux.Unlock()
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
if !accountIDFound {
return nil, status.Errorf( return nil, status.Errorf(
codes.NotFound, codes.NotFound,
"provided domain is not registered or is not private", "account not found: provided domain is not registered or is not private",
) )
} }
account, err := s.GetAccount(accountId) account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return account.Copy(), nil
} }
// GetAccountBySetupKey returns account by setup key id // GetAccountBySetupKey returns account by setup key id
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists")
}
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
return account, nil
}
// GetAccountPeers returns account peers
func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
account, err := s.GetAccount(accountId) accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists")
}
account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var peers []*Peer return account.Copy(), nil
for _, peer := range account.Peers {
peers = append(peers, peer)
}
return peers, nil
} }
// GetAllAccounts returns all accounts // GetAllAccounts returns all accounts
@@ -425,9 +235,9 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
return all return all
} }
// GetAccount returns an account for id // getAccount returns a reference to the Account. Should not return a copy.
func (s *FileStore) GetAccount(accountId string) (*Account, error) { func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, accountFound := s.Accounts[accountId] account, accountFound := s.Accounts[accountID]
if !accountFound { if !accountFound {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
@@ -435,139 +245,53 @@ func (s *FileStore) GetAccount(accountId string) (*Account, error) {
return account, nil return account, nil
} }
// GetUserAccount returns a user account // GetAccount returns an account for ID
func (s *FileStore) GetUserAccount(userId string) (*Account, error) { func (s *FileStore) GetAccount(accountID string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
accountId, accountIdFound := s.UserId2AccountId[userId] account, err := s.getAccount(accountID)
if !accountIdFound { if err != nil {
return nil, err
}
return account.Copy(), nil
}
// GetAccountByUser returns a user account
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, accountIDFound := s.UserID2AccountID[userID]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
return s.GetAccount(accountId) account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Copy(), nil
} }
func (s *FileStore) getPeerAccount(peerKey string) (*Account, error) { // GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey] func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
if !accountIdFound { s.mux.Lock()
defer s.mux.Unlock()
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey) return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
} }
return s.GetAccount(accountId) account, err := s.getAccount(accountID)
}
// GetPeerAccount returns user account if exists
func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
return s.getPeerAccount(peerKey)
}
// GetPeerSrcRules return list of source rules for peer
func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey] return account.Copy(), nil
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
}
// GetPeerDstRules return list of destination rules for peer
func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey]
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
}
// GetPeerRoutes return list of routes for peer
func (s *FileStore) GetPeerRoutes(peerKey string) ([]*route.Route, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getPeerAccount(peerKey)
if err != nil {
return nil, err
}
var routes []*route.Route
routeIDs, ok := s.PeerKeyID2RouteIDs[peerKey]
if !ok {
return routes, nil
}
for id := range routeIDs {
route, found := account.Routes[id]
if found {
routes = append(routes, route)
}
}
return routes, nil
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (s *FileStore) GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountID)
if err != nil {
return nil, err
}
routeIDs, ok := s.AccountPrefix2RouteIDs[accountID][prefix.String()]
if !ok {
return nil, status.Errorf(codes.NotFound, "no routes for prefix: %v", prefix.String())
}
var routes []*route.Route
for _, id := range routeIDs {
route, found := account.Routes[id]
if found {
routes = append(routes, route)
}
}
return routes, nil
} }
// GetInstallationID returns the installation ID from the store // GetInstallationID returns the installation ID from the store
@@ -576,11 +300,42 @@ func (s *FileStore) GetInstallationID() string {
} }
// SaveInstallationID saves the installation ID // SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(id string) error { func (s *FileStore) SaveInstallationID(ID string) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.InstallationID = id s.InstallationID = ID
return s.persist(s.storeFile)
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerStatus) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
peer := account.Peers[peerKey]
if peer == nil {
return status.Errorf(codes.NotFound, "peer %s not found", peerKey)
}
peer.Status = &peerStatus
return nil
}
// Close the FileStore persisting data to disk
func (s *FileStore) Close() error {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("closing FileStore")
return s.persist(s.storeFile) return s.persist(s.storeFile)
} }

View File

@@ -2,6 +2,7 @@ package server
import ( import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net" "net"
"path/filepath" "path/filepath"
@@ -9,6 +10,10 @@ import (
"time" "time"
) )
type accounts struct {
Accounts map[string]*Account
}
func TestNewStore(t *testing.T) { func TestNewStore(t *testing.T) {
store := newStore(t) store := newStore(t)
@@ -16,16 +21,16 @@ func TestNewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
} }
if store.SetupKeyId2AccountId == nil || len(store.SetupKeyId2AccountId) != 0 { if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 {
t.Errorf("expected to create a new empty SetupKeyId2AccountId map when creating a new FileStore") t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore")
} }
if store.PeerKeyId2AccountId == nil || len(store.PeerKeyId2AccountId) != 0 { if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 {
t.Errorf("expected to create a new empty PeerKeyId2AccountId map when creating a new FileStore") t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore")
} }
if store.UserId2AccountId == nil || len(store.UserId2AccountId) != 0 { if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 {
t.Errorf("expected to create a new empty UserId2AccountId map when creating a new FileStore") t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore")
} }
} }
@@ -55,16 +60,16 @@ func TestSaveAccount(t *testing.T) {
t.Errorf("expecting Account to be stored after SaveAccount()") t.Errorf("expecting Account to be stored after SaveAccount()")
} }
if store.PeerKeyId2AccountId["peerkey"] == "" { if store.PeerKeyID2AccountID["peerkey"] == "" {
t.Errorf("expecting PeerKeyId2AccountId index updated after SaveAccount()") t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()")
} }
if store.UserId2AccountId["testuser"] == "" { if store.UserID2AccountID["testuser"] == "" {
t.Errorf("expecting UserId2AccountId index updated after SaveAccount()") t.Errorf("expecting UserID2AccountID index updated after SaveAccount()")
} }
if store.SetupKeyId2AccountId[setupKey.Key] == "" { if store.SetupKeyID2AccountID[setupKey.Key] == "" {
t.Errorf("expecting SetupKeyId2AccountId index updated after SaveAccount()") t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()")
} }
} }
@@ -88,7 +93,7 @@ func TestStore(t *testing.T) {
return return
} }
restored, err := NewStore(store.storeFile) restored, err := NewFileStore(store.storeFile)
if err != nil { if err != nil {
return return
} }
@@ -124,7 +129,7 @@ func TestRestore(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewStore(storeDir) store, err := NewFileStore(storeDir)
if err != nil { if err != nil {
return return
} }
@@ -141,11 +146,11 @@ func TestRestore(t *testing.T) {
require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.Len(t, store.UserId2AccountId, 2, "failed to restore a FileStore wrong UserId2AccountId mapping length") require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length")
require.Len(t, store.SetupKeyId2AccountId, 1, "failed to restore a FileStore wrong SetupKeyId2AccountId mapping length") require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length")
require.Len(t, store.PrivateDomain2AccountId, 1, "failed to restore a FileStore wrong PrivateDomain2AccountId mapping length") require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length")
} }
func TestGetAccountByPrivateDomain(t *testing.T) { func TestGetAccountByPrivateDomain(t *testing.T) {
@@ -156,7 +161,7 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewStore(storeDir) store, err := NewFileStore(storeDir)
if err != nil { if err != nil {
return return
} }
@@ -171,8 +176,101 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
require.Error(t, err, "should return error on domain lookup") require.Error(t, err, "should return error on domain lookup")
} }
func TestFileStore_GetAccount(t *testing.T) {
storeDir := t.TempDir()
storeFile := filepath.Join(storeDir, "store.json")
err := util.CopyFileContents("testdata/store.json", storeFile)
if err != nil {
t.Fatal(err)
}
accounts := &accounts{}
_, err = util.ReadJson(storeFile, accounts)
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir)
if err != nil {
t.Fatal(err)
}
expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
if expected == nil {
t.Fatalf("expected account doesn't exist")
}
account, err := store.GetAccount(expected.Id)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
assert.Equal(t, expected.DomainCategory, account.DomainCategory)
assert.Equal(t, expected.Domain, account.Domain)
assert.Equal(t, expected.CreatedBy, account.CreatedBy)
assert.Equal(t, expected.Network.Id, account.Network.Id)
assert.Len(t, account.Peers, len(expected.Peers))
assert.Len(t, account.Users, len(expected.Users))
assert.Len(t, account.SetupKeys, len(expected.SetupKeys))
assert.Len(t, account.Rules, len(expected.Rules))
assert.Len(t, account.Routes, len(expected.Routes))
assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
}
func TestFileStore_SavePeerStatus(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir)
if err != nil {
return
}
account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
if err != nil {
t.Fatal(err)
}
// save status of non-existing peer
newStatus := PeerStatus{Connected: true, LastSeen: time.Now()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
// save new status of existing peer
account.Peers["testpeer"] = &Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{},
Name: "peer name",
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
}
err = store.SaveAccount(account)
if err != nil {
t.Fatal(err)
}
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
if err != nil {
t.Fatal(err)
}
account, err = store.getAccount(account.Id)
if err != nil {
t.Fatal(err)
}
actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual)
}
func newStore(t *testing.T) *FileStore { func newStore(t *testing.T) *FileStore {
store, err := NewStore(t.TempDir()) store, err := NewFileStore(t.TempDir())
if err != nil { if err != nil {
t.Errorf("failed creating a new store") t.Errorf("failed creating a new store")
} }

View File

@@ -47,8 +47,9 @@ func (g *Group) Copy() *Group {
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -65,8 +66,9 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error { func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -86,8 +88,9 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
// UpdateGroup updates a group using a list of operations // UpdateGroup updates a group using a list of operations
func (am *DefaultAccountManager) UpdateGroup(accountID string, func (am *DefaultAccountManager) UpdateGroup(accountID string,
groupID string, operations []GroupUpdateOperation) (*Group, error) { groupID string, operations []GroupUpdateOperation) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -135,8 +138,9 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
// DeleteGroup object of the peers // DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -155,8 +159,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -173,8 +178,9 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -207,8 +213,9 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -235,8 +242,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
// GroupListPeers returns list of the peers from the group // GroupListPeers returns list of the peers from the group
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) { func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {

View File

@@ -3,7 +3,7 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/telemetry"
gPeer "google.golang.org/grpc/peer" gPeer "google.golang.org/grpc/peer"
"strings" "strings"
"time" "time"
@@ -30,10 +30,12 @@ type GRPCServer struct {
config *Config config *Config
turnCredentialsManager TURNCredentialsManager turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware jwtMiddleware *middleware.JWTMiddleware
appMetrics telemetry.AppMetrics
} }
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager) (*GRPCServer, error) { func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager,
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -53,6 +55,16 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
log.Debug("unable to use http config to create new jwt middleware") log.Debug("unable to use http config to create new jwt middleware")
} }
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
})
if err != nil {
return nil, err
}
}
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
@@ -61,11 +73,15 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
config: config, config: config,
turnCredentialsManager: turnCredentialsManager, turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware, jwtMiddleware: jwtMiddleware,
appMetrics: appMetrics,
}, nil }, nil
} }
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
// todo introduce something more meaningful with the key expiration/rotation // todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
}
now := time.Now().Add(24 * time.Hour) now := time.Now().Add(24 * time.Hour)
secs := int64(now.Second()) secs := int64(now.Second())
nanos := int32(now.Nanosecond()) nanos := int32(now.Nanosecond())
@@ -80,6 +96,9 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account) // notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
p, ok := gRPCPeer.FromContext(srv.Context()) p, ok := gRPCPeer.FromContext(srv.Context())
if ok { if ok {
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
@@ -233,17 +252,14 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err) return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
} }
// notify other peers of our registration // notify other peers of our registration
for _, remotePeer := range networkMap.Peers { for _, remotePeer := range networkMap.Peers {
// exclude notified peer and add ourselves remotePeerNetworkMap, err := s.accountManager.GetNetworkMap(remotePeer.Key)
peersToSend := []*Peer{peer} if err != nil {
for _, p := range networkMap.Peers { return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
if remotePeer.Key != p.Key {
peersToSend = append(peersToSend, p)
}
} }
update := toSyncResponse(s.config, remotePeer, peersToSend, networkMap.Routes, nil, networkMap.Network.CurrentSerial(), networkMap.Network)
update := toSyncResponse(s.config, remotePeer, nil, remotePeerNetworkMap)
err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update}) err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update})
if err != nil { if err != nil {
// todo rethink if we should keep this return // todo rethink if we should keep this return
@@ -259,6 +275,9 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful // In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
p, ok := gRPCPeer.FromContext(ctx) p, ok := gRPCPeer.FromContext(ctx)
if ok { if ok {
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
@@ -370,6 +389,9 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
} }
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig { func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig var stuns []*proto.HostConfig
for _, stun := range config.Stuns { for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{ stuns = append(stuns, &proto.HostConfig{
@@ -428,14 +450,16 @@ func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
return remotePeers return remotePeers
} }
func toSyncResponse(config *Config, peer *Peer, peers []*Peer, routes []*route.Route, turnCredentials *TURNCredentials, serial uint64, network *Network) *proto.SyncResponse { func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials) wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, network) pConfig := toPeerConfig(peer, networkMap.Network)
remotePeers := toRemotePeerConfig(peers) remotePeers := toRemotePeerConfig(networkMap.Peers)
routesUpdate := toProtocolRoutes(routes) routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
return &proto.SyncResponse{ return &proto.SyncResponse{
WiretrusteeConfig: wtConfig, WiretrusteeConfig: wtConfig,
@@ -443,11 +467,12 @@ func toSyncResponse(config *Config, peer *Peer, peers []*Peer, routes []*route.R
RemotePeers: remotePeers, RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0, RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: serial, Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig, PeerConfig: pConfig,
RemotePeers: remotePeers, RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0, RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate, Routes: routesUpdate,
DNSConfig: dnsUpdate,
}, },
} }
} }
@@ -473,7 +498,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.
} else { } else {
turnCredentials = nil turnCredentials = nil
} }
plainResp := toSyncResponse(s.config, peer, networkMap.Peers, networkMap.Routes, turnCredentials, networkMap.Network.CurrentSerial(), networkMap.Network) plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {

View File

@@ -136,6 +136,9 @@ components:
ui_version: ui_version:
description: Peer's desktop UI version description: Peer's desktop UI version
type: string type: string
dns_label:
description: Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
type: string
required: required:
- ip - ip
- connected - connected
@@ -145,6 +148,7 @@ components:
- groups - groups
- ssh_enabled - ssh_enabled
- hostname - hostname
- dns_label
SetupKey: SetupKey:
type: object type: object
properties: properties:
@@ -444,12 +448,24 @@ components:
type: array type: array
items: items:
type: string type: string
primary:
description: Nameserver group primary status
type: boolean
domains:
description: Nameserver group domain list
type: array
items:
type: string
minLength: 1
maxLength: 255
required: required:
- name - name
- description - description
- nameservers - nameservers
- enabled - enabled
- groups - groups
- primary
- domains
NameserverGroup: NameserverGroup:
allOf: allOf:
- type: object - type: object
@@ -468,7 +484,7 @@ components:
path: path:
description: Nameserver group field to update in form /<field> description: Nameserver group field to update in form /<field>
type: string type: string
enum: [ "name","description","enabled","groups","nameservers" ] enum: [ "name", "description", "enabled", "groups", "nameservers", "primary", "domains" ]
required: required:
- path - path

View File

@@ -39,10 +39,12 @@ const (
// Defines values for NameserverGroupPatchOperationPath. // Defines values for NameserverGroupPatchOperationPath.
const ( const (
NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description" NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description"
NameserverGroupPatchOperationPathDomains NameserverGroupPatchOperationPath = "domains"
NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled" NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled"
NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups" NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups"
NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name" NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name"
NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers" NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers"
NameserverGroupPatchOperationPathPrimary NameserverGroupPatchOperationPath = "primary"
) )
// Defines values for PatchMinimumOp. // Defines values for PatchMinimumOp.
@@ -159,6 +161,9 @@ type NameserverGroup struct {
// Description Nameserver group description // Description Nameserver group description
Description string `json:"description"` Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status // Enabled Nameserver group status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@@ -173,6 +178,9 @@ type NameserverGroup struct {
// Nameservers Nameserver group // Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"` Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
} }
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation. // NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
@@ -198,6 +206,9 @@ type NameserverGroupRequest struct {
// Description Nameserver group description // Description Nameserver group description
Description string `json:"description"` Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status // Enabled Nameserver group status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@@ -209,6 +220,9 @@ type NameserverGroupRequest struct {
// Nameservers Nameserver group // Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"` Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
} }
// PatchMinimum defines model for PatchMinimum. // PatchMinimum defines model for PatchMinimum.
@@ -228,6 +242,9 @@ type Peer struct {
// Connected Peer to Management connection status // Connected Peer to Management connection status
Connected bool `json:"connected"` Connected bool `json:"connected"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// Groups Groups that the peer belongs to // Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"` Groups []GroupMinimum `json:"groups"`

View File

@@ -33,7 +33,7 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account // GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -50,7 +50,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
// UpdateGroupHandler handles update to a group identified by a given ID // UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -111,7 +111,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
// PatchGroupHandler handles patch updates to a group identified by a given ID // PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -226,7 +226,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
// CreateGroupHandler handles group creation request // CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -260,7 +260,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
// DeleteGroupHandler handles group deletion request // DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -295,7 +295,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
// GetGroupHandler returns a group // GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return

View File

@@ -21,11 +21,11 @@ import (
) )
var TestPeers = map[string]*server.Peer{ var TestPeers = map[string]*server.Peer{
"A": &server.Peer{Key: "A", IP: net.ParseIP("100.100.100.100")}, "A": {Key: "A", IP: net.ParseIP("100.100.100.100")},
"B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", IP: net.ParseIP("200.200.200.200")},
} }
func initGroupTestData(groups ...*server.Group) *Groups { func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
return &Groups{ return &Groups{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID string, group *server.Group) error { SaveGroupFunc: func(accountID string, group *server.Group) error {
@@ -72,6 +72,9 @@ func initGroupTestData(groups ...*server.Group) *Groups {
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: TestPeers, Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"}},
@@ -120,7 +123,8 @@ func TestGetGroup(t *testing.T) {
Name: "Group", Name: "Group",
} }
p := initGroupTestData(group) adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -270,7 +274,8 @@ func TestWriteGroup(t *testing.T) {
}, },
} }
p := initGroupTestData() adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@@ -4,12 +4,14 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/cors" "github.com/rs/cors"
"net/http" "net/http"
) )
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string) (http.Handler, error) { func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string,
appMetrics telemetry.AppMetrics) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware( jwtMiddleware, err := middleware.NewJwtMiddleware(
authIssuer, authIssuer,
authAudience, authAudience,
@@ -21,12 +23,15 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
acMiddleware := middleware.NewAccessControll( acMiddleware := middleware.NewAccessControl(
authAudience, authAudience,
accountManager.IsUserAdmin) accountManager.IsUserAdmin)
apiHandler := mux.NewRouter() rootRouter := mux.NewRouter()
apiHandler.Use(corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) metricsMiddleware := appMetrics.HTTPMiddleware()
apiHandler := rootRouter.PathPrefix("/api").Subrouter()
apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
groupsHandler := NewGroups(accountManager, authAudience) groupsHandler := NewGroups(accountManager, authAudience)
rulesHandler := NewRules(accountManager, authAudience) rulesHandler := NewRules(accountManager, authAudience)
@@ -36,46 +41,67 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
routesHandler := NewRoutes(accountManager, authAudience) routesHandler := NewRoutes(accountManager, authAudience)
nameserversHandler := NewNameservers(accountManager, authAudience) nameserversHandler := NewNameservers(accountManager, authAudience)
apiHandler.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer). apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS") Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS")
return apiHandler, nil err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err
}
for _, method := range methods {
template, err := route.GetPathTemplate()
if err != nil {
return err
}
err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return rootRouter, nil
} }

View File

@@ -9,24 +9,25 @@ import (
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControll middleware to restrict to make POST/PUT/DELETE requests by admin only // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControll struct { type AccessControl struct {
jwtExtractor jwtclaims.ClaimsExtractor jwtExtractor jwtclaims.ClaimsExtractor
isUserAdmin IsUserAdminFunc isUserAdmin IsUserAdminFunc
audience string audience string
} }
// NewAccessControll instance constructor // NewAccessControl instance constructor
func NewAccessControll(audience string, isUserAdmin IsUserAdminFunc) *AccessControll { func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl {
return &AccessControll{ return &AccessControl{
isUserAdmin: isUserAdmin, isUserAdmin: isUserAdmin,
audience: audience, audience: audience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
} }
} }
// Handler method of the middleware which forbinneds all modify requests for non admin users // Handler method of the middleware which forbids all modify requests for non admin users
func (a *AccessControll) Handler(h http.Handler) http.Handler { // It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience) jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience)
@@ -34,7 +35,6 @@ func (a *AccessControll) Handler(h http.Handler) http.Handler {
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized) http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized)
return return
} }
if !ok { if !ok {

View File

@@ -186,7 +186,7 @@ func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Reque
validatedToken, err := m.ValidateAndParse(token) validatedToken, err := m.ValidateAndParse(token)
if err != nil { if err != nil {
m.Options.ErrorHandler(w, r, "The token isn't valid") m.Options.ErrorHandler(w, r, err.Error())
return err return err
} }

View File

@@ -30,7 +30,7 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
// GetAllNameserversHandler returns the list of nameserver groups for the account // GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -53,7 +53,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
// CreateNameserverGroupHandler handles nameserver group creation request // CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -71,7 +71,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Enabled) nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
if err != nil { if err != nil {
toHTTPError(err, w) toHTTPError(err, w)
return return
@@ -84,7 +84,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID // UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -113,6 +113,8 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
ID: nsGroupID, ID: nsGroupID,
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Primary: req.Primary,
Domains: req.Domains,
NameServers: nsList, NameServers: nsList,
Groups: req.Groups, Groups: req.Groups,
Enabled: req.Enabled, Enabled: req.Enabled,
@@ -131,7 +133,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID // PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -168,6 +170,16 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
Type: server.UpdateNameServerGroupDescription, Type: server.UpdateNameServerGroupDescription,
Values: patch.Value, Values: patch.Value,
}) })
case api.NameserverGroupPatchOperationPathPrimary:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupPrimary,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathDomains:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupDomains,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathNameservers: case api.NameserverGroupPatchOperationPathNameservers:
operations = append(operations, server.NameServerGroupUpdateOperation{ operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupNameServers, Type: server.UpdateNameServerGroupNameServers,
@@ -202,7 +214,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
// DeleteNameserverGroupHandler handles nameserver group deletion request // DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -226,7 +238,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID // GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -279,6 +291,8 @@ func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.Namese
Id: serverNSGroup.ID, Id: serverNSGroup.ID,
Name: serverNSGroup.Name, Name: serverNSGroup.Name,
Description: serverNSGroup.Description, Description: serverNSGroup.Description,
Primary: serverNSGroup.Primary,
Domains: serverNSGroup.Domains,
Groups: serverNSGroup.Groups, Groups: serverNSGroup.Groups,
Nameservers: nsList, Nameservers: nsList,
Enabled: serverNSGroup.Enabled, Enabled: serverNSGroup.Enabled,

View File

@@ -29,12 +29,16 @@ const (
var testingNSAccount = &server.Account{ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID, Id: testNSGroupAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
} }
var baseExistingNSGroup = &nbdns.NameServerGroup{ var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -60,7 +64,7 @@ func initNameserversTestData() *Nameservers {
} }
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
}, },
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{ return &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: name, Name: name,
@@ -68,6 +72,8 @@ func initNameserversTestData() *Nameservers {
NameServers: nameServerList, NameServers: nameServerList,
Groups: groups, Groups: groups,
Enabled: enabled, Enabled: enabled,
Primary: primary,
Domains: domains,
}, nil }, nil
}, },
DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error { DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error {
@@ -150,7 +156,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/dns/nameservers", requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedNSGroup: &api.NameserverGroup{ expectedNSGroup: &api.NameserverGroup{
@@ -166,6 +172,7 @@ func TestNameserversHandlers(t *testing.T) {
}, },
Groups: []string{"group"}, Groups: []string{"group"},
Enabled: true, Enabled: true,
Primary: true,
}, },
}, },
{ {
@@ -173,7 +180,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/dns/nameservers", requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedBody: false, expectedBody: false,
}, },
@@ -182,7 +189,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + existingNSGroupID, requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedNSGroup: &api.NameserverGroup{ expectedNSGroup: &api.NameserverGroup{
@@ -198,6 +205,7 @@ func TestNameserversHandlers(t *testing.T) {
}, },
Groups: []string{"group"}, Groups: []string{"group"},
Enabled: true, Enabled: true,
Primary: true,
}, },
}, },
{ {
@@ -205,7 +213,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
@@ -214,7 +222,7 @@ func TestNameserversHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedBody: false, expectedBody: false,
}, },
@@ -232,6 +240,7 @@ func TestNameserversHandlers(t *testing.T) {
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers, Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
Groups: baseExistingNSGroup.Groups, Groups: baseExistingNSGroup.Groups,
Enabled: baseExistingNSGroup.Enabled, Enabled: baseExistingNSGroup.Enabled,
Primary: baseExistingNSGroup.Primary,
}, },
}, },
{ {

View File

@@ -56,7 +56,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
} }
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -95,15 +95,20 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
return
}
respBody := []*api.Peer{} respBody := []*api.Peer{}
for _, peer := range account.Peers { for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account)) respBody = append(respBody, toPeerResponse(peer, account))
} }
writeJSONObject(w, respBody) writeJSONObject(w, respBody)
@@ -147,5 +152,6 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Hostname: peer.Meta.Hostname, Hostname: peer.Meta.Hostname,
UserId: &peer.UserID, UserId: &peer.UserID,
UiVersion: &peer.Meta.UIVersion, UiVersion: &peer.Meta.UIVersion,
DnsLabel: peer.DNSLabel,
} }
} }

View File

@@ -16,15 +16,21 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
func initTestMetaData(peer ...*server.Peer) *Peers { func initTestMetaData(peers ...*server.Peer) *Peers {
return &Peers{ return &Peers{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: map[string]*server.Peer{ Peers: map[string]*server.Peer{
"test_peer": peer[0], "test_peer": peers[0],
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
}, },
}, nil }, nil
}, },

View File

@@ -33,16 +33,24 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account // GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
routes, err := h.accountManager.ListRoutes(account.Id) routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
@@ -56,7 +64,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
// CreateRouteHandler handles route creation request // CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -104,7 +112,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRouteHandler handles update to a route identified by a given ID // UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -117,7 +125,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID) _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound) http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
return return
@@ -177,7 +185,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
// PatchRouteHandler handles patch updates to a route identified by a given ID // PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -190,7 +198,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID) _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound) http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
@@ -334,7 +342,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRouteHandler handles route deletion request // DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -363,7 +371,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
// GetRouteHandler handles a route Get request identified by ID // GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -375,7 +383,7 @@ func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID) foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
http.Error(w, "route not found", http.StatusNotFound) http.Error(w, "route not found", http.StatusNotFound)
return return

View File

@@ -51,12 +51,15 @@ var testingAccount = &server.Account{
IP: netip.MustParseAddr(existingPeerID).AsSlice(), IP: netip.MustParseAddr(existingPeerID).AsSlice(),
}, },
}, },
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
} }
func initRoutesTestData() *Routes { func initRoutesTestData() *Routes {
return &Routes{ return &Routes{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID string) (*route.Route, error) { GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
if routeID == existingRouteID { if routeID == existingRouteID {
return baseExistingRoute, nil return baseExistingRoute, nil
} }

View File

@@ -31,15 +31,29 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account // GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
rules := []*api.Rule{} rules := []*api.Rule{}
for _, r := range account.Rules { for _, r := range accountRules {
rules = append(rules, toRuleResponse(account, r)) rules = append(rules, toRuleResponse(account, r))
} }
@@ -48,7 +62,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRuleHandler handles update to a rule identified by a given ID // UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -118,7 +132,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
// PatchRuleHandler handles patch updates to a rule identified by a given ID // PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -275,7 +289,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
// CreateRuleHandler handles rule creation request // CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -332,7 +346,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRuleHandler handles rule deletion request // DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -356,7 +370,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
// GetRuleHandler handles a group Get request identified by ID // GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
@@ -370,7 +384,7 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
rule, err := h.accountManager.GetRule(account.Id, ruleID) rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
if err != nil { if err != nil {
http.Error(w, "rule not found", http.StatusNotFound) http.Error(w, "rule not found", http.StatusNotFound)
return return

View File

@@ -28,7 +28,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
} }
return nil return nil
}, },
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) { GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
if ruleID != "idoftherule" { if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found") return nil, fmt.Errorf("not found")
} }
@@ -75,6 +75,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
"F": {ID: "F"}, "F": {ID: "F"},
"G": {ID: "G"}, "G": {ID: "G"},
}, },
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}, nil }, nil
}, },
}, },

View File

@@ -31,7 +31,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey // CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -81,7 +81,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID // GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -95,7 +95,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
key, err := h.accountManager.GetSetupKey(account.Id, keyID) key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound { if ok && errStatus.Code() == codes.NotFound {
@@ -112,7 +112,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey // UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -168,14 +168,14 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey // GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
setupKeys, err := h.accountManager.ListSetupKeys(account.Id) setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@@ -28,13 +28,17 @@ const (
notFoundSetupKeyID = "notFoundSetupKeyID" notFoundSetupKeyID = "notFoundSetupKeyID"
) )
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys { func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys {
return &SetupKeys{ return &SetupKeys{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{ return &server.Account{
Id: testAccountID, Id: testAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{ SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey, defaultKey.Key: defaultKey,
}, },
@@ -49,7 +53,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
} }
return nil, fmt.Errorf("failed creating setup key") return nil, fmt.Errorf("failed creating setup key")
}, },
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) { GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID { switch keyID {
case defaultKey.Id: case defaultKey.Id:
return defaultKey, nil return defaultKey, nil
@@ -67,7 +71,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id) return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
}, },
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) { ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil return []*server.SetupKey{defaultKey}, nil
}, },
}, },
@@ -75,7 +79,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims { ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: user.Id,
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: testAccountID, AccountId: testAccountID,
} }
@@ -88,6 +92,8 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey() defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}) newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
@@ -153,7 +159,7 @@ func TestSetupKeysHandlers(t *testing.T) {
}, },
} }
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey) handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@@ -34,7 +34,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest) http.Error(w, "", http.StatusBadRequest)
} }
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -87,7 +87,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
http.Error(w, "", http.StatusNotFound) http.Error(w, "", http.StatusNotFound)
} }
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
@@ -132,12 +132,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest) http.Error(w, "", http.StatusBadRequest)
} }
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil { if err != nil {
log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError)
return
} }
data, err := h.accountManager.GetUsersFromAccount(account.Id) data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@@ -27,7 +27,7 @@ func initUsers(user ...*server.User) *UserHandler {
Users: users, Users: users,
}, nil }, nil
}, },
GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)
for _, v := range user { for _, v := range user {
users = append(users, &server.UserInfo{ users = append(users, &server.UserInfo{
@@ -44,7 +44,7 @@ func initUsers(user ...*server.User) *UserHandler {
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "1",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_id", AccountId: "test_id",
} }

View File

@@ -56,16 +56,22 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
func getJWTAccount(accountManager server.AccountManager, func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor, jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, error) { authAudience string, r *http.Request) (*server.Account, *server.User, error) {
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience) claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(jwtClaims) account, err := accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
} }
return account, nil user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
}
return account, user, nil
} }
func toHTTPError(err error, w http.ResponseWriter) { func toHTTPError(err error, w http.ResponseWriter) {

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@@ -24,6 +25,7 @@ type Auth0Manager struct {
httpClient ManagerHTTPClient httpClient ManagerHTTPClient
credentials ManagerCredentials credentials ManagerCredentials
helper ManagerHelper helper ManagerHelper
appMetrics telemetry.AppMetrics
} }
// Auth0ClientConfig auth0 manager client configurations // Auth0ClientConfig auth0 manager client configurations
@@ -51,6 +53,7 @@ type Auth0Credentials struct {
httpClient ManagerHTTPClient httpClient ManagerHTTPClient
jwtToken JWTToken jwtToken JWTToken
mux sync.Mutex mux sync.Mutex
appMetrics telemetry.AppMetrics
} }
// createUserRequest is a user create request // createUserRequest is a user create request
@@ -106,7 +109,7 @@ type auth0Profile struct {
} }
// NewAuth0Manager creates a new instance of the Auth0Manager // NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) { func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5 httpTransport.MaxIdleConns = 5
@@ -134,12 +137,15 @@ func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) {
clientConfig: config, clientConfig: config,
httpClient: httpClient, httpClient: httpClient,
helper: helper, helper: helper,
appMetrics: appMetrics,
} }
return &Auth0Manager{ return &Auth0Manager{
authIssuer: config.AuthIssuer, authIssuer: config.AuthIssuer,
credentials: credentials, credentials: credentials,
httpClient: httpClient, httpClient: httpClient,
helper: helper, helper: helper,
appMetrics: appMetrics,
}, nil }, nil
} }
@@ -170,6 +176,9 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
res, err = c.httpClient.Do(req) res, err = c.httpClient.Do(req)
if err != nil { if err != nil {
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountRequestError()
}
return res, err return res, err
} }
@@ -214,6 +223,10 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountAuthenticate()
}
// If jwtToken has an expires time and we have enough time to do a request return immediately // If jwtToken has an expires time and we have enough time to do a request return immediately
if c.jwtStillValid() { if c.jwtStillValid() {
return c.jwtToken, nil return c.jwtToken, nil
@@ -287,9 +300,16 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -342,9 +362,16 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -398,9 +425,16 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err return err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
defer func() { defer func() {
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
@@ -416,12 +450,13 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
} }
func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) { func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) {
invite := true
req := &createUserRequest{ req := &createUserRequest{
Email: email, Email: email,
Name: name, Name: name,
AppMeta: AppMetadata{ AppMeta: AppMetadata{
WTAccountID: accountID, WTAccountID: accountID,
WTPendingInvite: true, WTPendingInvite: &invite,
}, },
Connection: "Username-Password-Authentication", Connection: "Username-Password-Authentication",
Password: GeneratePassword(8, 1, 1, 1), Password: GeneratePassword(8, 1, 1, 1),
@@ -502,6 +537,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
jobResp, err := am.httpClient.Do(exportJobReq) jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil { if err != nil {
log.Debugf("Couldn't get job response %v", err) log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
@@ -512,6 +550,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
} }
}() }()
if jobResp.StatusCode != 200 { if jobResp.StatusCode != 200 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode) return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
} }
@@ -530,6 +571,9 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
} }
if exportJobResp.ID == "" { if exportJobResp.ID == "" {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
} }
@@ -556,12 +600,16 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + email reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken) body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
userResp := []*UserData{} userResp := []*UserData{}
err = am.helper.Unmarshal(body, &userResp) err = am.helper.Unmarshal(body, &userResp)
@@ -585,9 +633,16 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
return nil, err return nil, err
} }
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := am.httpClient.Do(req) resp, err := am.httpClient.Do(req)
if err != nil { if err != nil {
log.Debugf("Couldn't get job response %v", err) log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err return nil, err
} }
@@ -598,6 +653,9 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
} }
}() }()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) { if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
} }
@@ -698,7 +756,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
Email: profile.Email, Email: profile.Email,
AppMetadata: AppMetadata{ AppMetadata: AppMetadata{
WTAccountID: profile.AccountID, WTAccountID: profile.AccountID,
WTPendingInvite: profile.PendingInvite, WTPendingInvite: &profile.PendingInvite,
}, },
}) })
} }
@@ -729,13 +787,12 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
log.Errorf("error while closing body for url %s: %v", url, err) log.Errorf("error while closing body for url %s: %v", url, err)
} }
}() }()
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
return body, nil return body, nil
} }

View File

@@ -3,6 +3,7 @@ package idp
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io" "io"
"net/http" "net/http"
@@ -340,7 +341,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code", name: "Bad Status Code",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID), expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
appMetadata: appMetadata, appMetadata: appMetadata,
statusCode: 400, statusCode: 400,
helper: JsonParser{}, helper: JsonParser{},
@@ -363,7 +364,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request", name: "Good request",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID), expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
appMetadata: appMetadata, appMetadata: appMetadata,
statusCode: 200, statusCode: 200,
helper: JsonParser{}, helper: JsonParser{},
@@ -371,7 +372,23 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
assertErrFuncMessage: "shouldn't return error", assertErrFuncMessage: "shouldn't return error",
} }
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} { invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{ jwtReqClient := mockHTTPClient{
resBody: testCase.inputReqBody, resBody: testCase.inputReqBody,
@@ -459,7 +476,7 @@ func TestNewAuth0Manager(t *testing.T) {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuth0Manager(testCase.inputConfig) _, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
}) })
} }

View File

@@ -2,6 +2,7 @@ package idp
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@@ -51,7 +52,7 @@ type AppMetadata struct {
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP // WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
// maps to wt_account_id when json.marshal // maps to wt_account_id when json.marshal
WTAccountID string `json:"wt_account_id,omitempty"` WTAccountID string `json:"wt_account_id,omitempty"`
WTPendingInvite bool `json:"wt_pending_invite"` WTPendingInvite *bool `json:"wt_pending_invite"`
} }
// JWTToken a JWT object that holds information of a token // JWTToken a JWT object that holds information of a token
@@ -64,12 +65,12 @@ type JWTToken struct {
} }
// NewManager returns a new idp manager based on the configuration that it receives // NewManager returns a new idp manager based on the configuration that it receives
func NewManager(config Config) (Manager, error) { func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
switch strings.ToLower(config.ManagerType) { switch strings.ToLower(config.ManagerType) {
case "none", "": case "none", "":
return nil, nil return nil, nil
case "auth0": case "auth0":
return NewAuth0Manager(config.Auth0ClientCredentials) return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
default: default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
} }

View File

@@ -398,17 +398,17 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro
return nil, err return nil, err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := NewStore(config.Datadir) store, err := NewFileStore(config.Datadir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peersUpdateManager := NewPeersUpdateManager() peersUpdateManager := NewPeersUpdateManager()
accountManager, err := BuildManager(store, peersUpdateManager, nil) accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -488,17 +488,17 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer() s := grpc.NewServer()
store, err := server.NewStore(config.Datadir) store, err := server.NewFileStore(config.Datadir)
if err != nil { if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
} }
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil) accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer) mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() { go func() {

View File

@@ -17,7 +17,7 @@ import (
const ( const (
// PayloadEvent identifies an event type // PayloadEvent identifies an event type
PayloadEvent = "self-hosted stats" PayloadEvent = "self-hosted stats"
// payloadEndpoint metrics endpoint to send anonymous data // payloadEndpoint metrics defaultEndpoint to send anonymous data
payloadEndpoint = "https://metrics.netbird.io" payloadEndpoint = "https://metrics.netbird.io"
// defaultPushInterval default interval to push metrics // defaultPushInterval default interval to push metrics
defaultPushInterval = 24 * time.Hour defaultPushInterval = 24 * time.Hour
@@ -188,13 +188,17 @@ func (w *Worker) generateProperties() properties {
userPeers++ userPeers++
} }
_, connected := connections[peer.Key]
if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++
}
osKey := strings.ToLower(fmt.Sprintf("peer_os_%s", peer.Meta.GoOS)) osKey := strings.ToLower(fmt.Sprintf("peer_os_%s", peer.Meta.GoOS))
osCount := osPeers[osKey] osCount := osPeers[osKey]
osPeers[osKey] = osCount + 1 osPeers[osKey] = osCount + 1
_, connected := connections[peer.Key]
if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++
osActiveKey := osKey + "_active"
osActiveCount := osPeers[osActiveKey]
osPeers[osActiveKey] = osActiveCount + 1
}
} }
} }
@@ -279,5 +283,4 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
req.Header.Add("content-type", "application/json") req.Header.Add("content-type", "application/json")
return req, nil return req, nil
} }

View File

@@ -1,13 +0,0 @@
## Migration from Store v2 to Store v2
Previously Account.Id was an Auth0 user id.
Conversion moves user id to Account.CreatedBy and generates a new Account.Id using xid.
It also adds a User with id = old Account.Id with a role Admin.
To start a conversion simply run the command below providing your current Wiretrustee Management datadir (where store.json file is located)
and a new data directory location (where a converted store.js will be stored):
```shell
./migration --oldDir /var/wiretrustee/datadir --newDir /var/wiretrustee/newdatadir/
```
Afterwards you can run the Management service providing ```/var/wiretrustee/newdatadir/ ``` as a datadir.

View File

@@ -1,56 +0,0 @@
package main
import (
"flag"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/rs/xid"
)
func main() {
oldDir := flag.String("oldDir", "old store directory", "/var/wiretrustee/datadir")
newDir := flag.String("newDir", "new store directory", "/var/wiretrustee/newdatadir")
flag.Parse()
oldStore, err := server.NewStore(*oldDir)
if err != nil {
panic(err)
}
newStore, err := server.NewStore(*newDir)
if err != nil {
panic(err)
}
err = Convert(oldStore, newStore)
if err != nil {
panic(err)
}
fmt.Println("successfully converted")
}
// Convert converts old store ato a new store
// Previously Account.Id was an Auth0 user id
// Conversion moved user id to Account.CreatedBy and generated a new Account.Id using xid
// It also adds a User with id = old Account.Id with a role Admin
func Convert(oldStore *server.FileStore, newStore *server.FileStore) error {
for _, account := range oldStore.Accounts {
accountCopy := account.Copy()
accountCopy.Id = xid.New().String()
accountCopy.CreatedBy = account.Id
accountCopy.Users[account.Id] = &server.User{
Id: account.Id,
Role: server.UserRoleAdmin,
}
err := newStore.SaveAccount(accountCopy)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,76 +0,0 @@
package main
import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
"path/filepath"
"testing"
)
func TestConvertAccounts(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("../testdata/storev1.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := server.NewStore(storeDir)
if err != nil {
t.Fatal(err)
}
convertedStore, err := server.NewStore(filepath.Join(storeDir, "converted"))
if err != nil {
t.Fatal(err)
}
err = Convert(store, convertedStore)
if err != nil {
t.Fatal(err)
}
if len(store.Accounts) != len(convertedStore.Accounts) {
t.Errorf("expecting the same number of accounts after conversion")
}
for _, account := range store.Accounts {
convertedAccount, err := convertedStore.GetUserAccount(account.Id)
if err != nil || convertedAccount == nil {
t.Errorf("expecting Account %s to be converted", account.Id)
return
}
if convertedAccount.CreatedBy != account.Id {
t.Errorf("expecting converted Account.CreatedBy field to be equal to the old Account.Id")
return
}
if convertedAccount.Id == account.Id {
t.Errorf("expecting converted Account.Id to be different from Account.Id")
return
}
if len(convertedAccount.Users) != 1 {
t.Errorf("expecting converted Account.Users to be of size 1")
return
}
user := convertedAccount.Users[account.Id]
if user == nil {
t.Errorf("expecting to find a user in converted Account.Users")
return
}
if user.Role != server.UserRoleAdmin {
t.Errorf("expecting to find a user in converted Account.Users with a role Admin")
return
}
for peerId := range account.Peers {
convertedPeer := convertedAccount.Peers[peerId]
if convertedPeer == nil {
t.Errorf("expecting Account Peer of StoreV1 to be found in StoreV2")
return
}
}
}
}

View File

@@ -14,14 +14,13 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error) GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error) AccountExistsFunc func(accountId string) (*bool, error)
GetPeerFunc func(peerKey string) (*server.Peer, error) GetPeerFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error MarkPeerConnectedFunc func(peerKey string, connected bool) error
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
@@ -35,26 +34,26 @@ type MockAccountManager struct {
GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error) GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
SaveRuleFunc func(accountID string, rule *server.Rule) error SaveRuleFunc func(accountID string, rule *server.Rule) error
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
DeleteRuleFunc func(accountID, ruleID string) error DeleteRuleFunc func(accountID, ruleID string) error
ListRulesFunc func(accountID string) ([]*server.Rule, error) ListRulesFunc func(accountID, userID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error) UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
GetRouteFunc func(accountID, routeID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, route *route.Route) error SaveRouteFunc func(accountID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID string) ([]*route.Route, error) ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
@@ -64,13 +63,21 @@ type MockAccountManager struct {
} }
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.UserInfo, error) { func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil { if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(accountID) return am.GetUsersFromAccountFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
} }
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountId, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser( func (am *MockAccountManager) GetOrCreateAccountByUser(
userId, domain string, userId, domain string,
@@ -106,16 +113,8 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) {
if am.GetAccountByIdFunc != nil {
return am.GetAccountByIdFunc(accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById is not implemented")
}
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface // GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountId( func (am *MockAccountManager) GetAccountByUserOrAccountID(
userId, accountId, domain string, userId, accountId, domain string,
) (*server.Account, error) { ) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil { if am.GetAccountByUserOrAccountIdFunc != nil {
@@ -123,7 +122,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
} }
return nil, status.Errorf( return nil, status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountByUserOrAccountId is not implemented", "method GetAccountByUserOrAccountID is not implemented",
) )
} }
@@ -151,26 +150,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
// RenamePeer mock implementation of RenamePeer from server.AccountManager interface
func (am *MockAccountManager) RenamePeer(
accountId string,
peerKey string,
newName string,
) (*server.Peer, error) {
if am.RenamePeerFunc != nil {
return am.RenamePeerFunc(accountId, peerKey, newName)
}
return nil, status.Errorf(codes.Unimplemented, "method RenamePeer is not implemented")
}
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountId, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface // GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) { func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
if am.GetPeerByIPFunc != nil { if am.GetPeerByIPFunc != nil {
@@ -272,9 +251,9 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
} }
// GetRule mock implementation of GetRule from server.AccountManager interface // GetRule mock implementation of GetRule from server.AccountManager interface
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) { func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
if am.GetRuleFunc != nil { if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID) return am.GetRuleFunc(accountID, ruleID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented")
} }
@@ -304,9 +283,9 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
} }
// ListRules mock implementation of ListRules from server.AccountManager interface // ListRules mock implementation of ListRules from server.AccountManager interface
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) { func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) {
if am.ListRulesFunc != nil { if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID) return am.ListRulesFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented")
} }
@@ -345,16 +324,16 @@ func (am *MockAccountManager) UpdatePeer(accountID string, peer *server.Peer) (*
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
if am.GetRouteFunc != nil { if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled) return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
} }
// GetRoute mock implementation of GetRoute from server.AccountManager interface // GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) { func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil { if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID) return am.GetRouteFunc(accountID, routeID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented")
} }
@@ -384,9 +363,9 @@ func (am *MockAccountManager) DeleteRoute(accountID, routeID string) error {
} }
// ListRoutes mock implementation of ListRoutes from server.AccountManager interface // ListRoutes mock implementation of ListRoutes from server.AccountManager interface
func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, error) { func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
if am.ListRoutesFunc != nil { if am.ListRoutesFunc != nil {
return am.ListRoutesFunc(accountID) return am.ListRoutesFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
} }
@@ -401,18 +380,18 @@ func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKe
} }
// GetSetupKey mocks GetSetupKey of the AccountManager interface // GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) { func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil { if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, keyID) return am.GetSetupKeyFunc(accountID, userID, keyID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
} }
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface // ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) { func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil { if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID) return am.ListSetupKeysFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
@@ -435,9 +414,9 @@ func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*
} }
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface // CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil { if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, enabled) return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled)
} }
return nil, nil return nil, nil
} }
@@ -489,3 +468,11 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
} }
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer, error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetPeersFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeersFunc is not implemented")
}

View File

@@ -1,8 +1,10 @@
package server package server
import ( import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"strconv" "strconv"
@@ -20,6 +22,10 @@ const (
UpdateNameServerGroupGroups UpdateNameServerGroupGroups
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation // UpdateNameServerGroupEnabled indicates a nameserver group status update operation
UpdateNameServerGroupEnabled UpdateNameServerGroupEnabled
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
UpdateNameServerGroupPrimary
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
UpdateNameServerGroupDomains
) )
// NameServerGroupUpdateOperationType operation type // NameServerGroupUpdateOperationType operation type
@@ -37,6 +43,10 @@ func (t NameServerGroupUpdateOperationType) String() string {
return "UpdateNameServerGroupGroups" return "UpdateNameServerGroupGroups"
case UpdateNameServerGroupEnabled: case UpdateNameServerGroupEnabled:
return "UpdateNameServerGroupEnabled" return "UpdateNameServerGroupEnabled"
case UpdateNameServerGroupPrimary:
return "UpdateNameServerGroupPrimary"
case UpdateNameServerGroupDomains:
return "UpdateNameServerGroupDomains"
default: default:
return "InvalidOperation" return "InvalidOperation"
} }
@@ -50,8 +60,9 @@ type NameServerGroupUpdateOperation struct {
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -67,9 +78,10 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -83,6 +95,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
NameServers: nameServerList, NameServers: nameServerList,
Groups: groups, Groups: groups,
Enabled: enabled, Enabled: enabled,
Primary: primary,
Domains: domains,
} }
err = validateNameServerGroup(false, newNSGroup, account) err = validateNameServerGroup(false, newNSGroup, account)
@@ -102,13 +116,20 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
return nil, err return nil, err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after create nameserver %s", name)
}
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
} }
// SaveNameServerGroup saves nameserver group // SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error { func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if nsGroupToSave == nil { if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil") return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
@@ -132,13 +153,20 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
return err return err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
}
return nil return nil
} }
// UpdateNameServerGroup updates existing nameserver group with set of operations // UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -205,6 +233,18 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
} }
newNSGroup.Enabled = enabled newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
}
newNSGroup.Primary = primary
case UpdateNameServerGroupDomains:
err = validateDomainInput(false, operation.Values)
if err != nil {
return nil, err
}
newNSGroup.Domains = operation.Values
} }
} }
@@ -216,13 +256,20 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
return nil, err return nil, err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", newNSGroup.Name)
}
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
} }
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -237,13 +284,20 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
return err return err
} }
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after deleting nameserver %s", nsGroupID)
}
return nil return nil
} }
// ListNameServerGroups returns a list of nameserver groups from account // ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -268,7 +322,12 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
} }
} }
err := validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil { if err != nil {
return err return err
} }
@@ -286,6 +345,24 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
return nil return nil
} }
func validateDomainInput(primary bool, domains []string) error {
if !primary && len(domains) == 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain")
}
if primary && len(domains) != 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
" you should set either primary or domain")
}
for _, domain := range domains {
_, valid := dns.IsDomainName(domain)
if !valid {
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
}
}
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)

View File

@@ -14,6 +14,8 @@ const (
existingNSGroupID = "existingNSGroup" existingNSGroupID = "existingNSGroup"
nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI=" nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI="
validDomain = "example.com"
invalidDomain = "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com"
) )
func TestCreateNameServerGroup(t *testing.T) { func TestCreateNameServerGroup(t *testing.T) {
@@ -23,6 +25,8 @@ func TestCreateNameServerGroup(t *testing.T) {
enabled bool enabled bool
groups []string groups []string
nameServers []nbdns.NameServer nameServers []nbdns.NameServer
primary bool
domains []string
} }
testCases := []struct { testCases := []struct {
@@ -33,11 +37,12 @@ func TestCreateNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup expectedNSGroup *nbdns.NameServerGroup
}{ }{
{ {
name: "Create A NS Group", name: "Create A NS Group With Primary Status",
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
groups: []string{group1ID}, groups: []string{group1ID},
primary: true,
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -57,6 +62,52 @@ func TestCreateNameServerGroup(t *testing.T) {
expectedNSGroup: &nbdns.NameServerGroup{ expectedNSGroup: &nbdns.NameServerGroup{
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
Groups: []string{group1ID},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Enabled: true,
},
},
{
name: "Create A NS Group With Domains",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
primary: false,
domains: []string{validDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
Name: "super",
Description: "super",
Primary: false,
Domains: []string{"example.com"},
Groups: []string{group1ID}, Groups: []string{group1ID},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
@@ -78,6 +129,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: existingNSGroupName, name: existingNSGroupName,
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -101,6 +153,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "", name: "",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -124,6 +177,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "1234567890123456789012345678901234567890extra", name: "1234567890123456789012345678901234567890extra",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -147,6 +201,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{}, nameServers: []nbdns.NameServer{},
enabled: true, enabled: true,
@@ -159,6 +214,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{group1ID}, groups: []string{group1ID},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -187,6 +243,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{}, groups: []string{},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -210,6 +267,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{"missingGroup"}, groups: []string{"missingGroup"},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -233,6 +291,7 @@ func TestCreateNameServerGroup(t *testing.T) {
inputArgs: input{ inputArgs: input{
name: "super", name: "super",
description: "super", description: "super",
primary: true,
groups: []string{""}, groups: []string{""},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
@@ -251,6 +310,53 @@ func TestCreateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{
name: "Should Not Create If No Domain Or Primary",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Domain List Is Invalid",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
domains: []string{invalidDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
@@ -270,6 +376,8 @@ func TestCreateNameServerGroup(t *testing.T) {
testCase.inputArgs.description, testCase.inputArgs.description,
testCase.inputArgs.nameServers, testCase.inputArgs.nameServers,
testCase.inputArgs.groups, testCase.inputArgs.groups,
testCase.inputArgs.primary,
testCase.inputArgs.domains,
testCase.inputArgs.enabled, testCase.inputArgs.enabled,
) )
@@ -295,6 +403,7 @@ func TestSaveNameServerGroup(t *testing.T) {
ID: "testingNSGroup", ID: "testingNSGroup",
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -313,6 +422,10 @@ func TestSaveNameServerGroup(t *testing.T) {
validGroups := []string{group2ID} validGroups := []string{group2ID}
invalidGroups := []string{"nonExisting"} invalidGroups := []string{"nonExisting"}
disabledPrimary := false
validDomains := []string{validDomain}
invalidDomains := []string{invalidDomain}
validNameServerList := []nbdns.NameServer{ validNameServerList := []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -348,6 +461,8 @@ func TestSaveNameServerGroup(t *testing.T) {
existingNSGroup *nbdns.NameServerGroup existingNSGroup *nbdns.NameServerGroup
newID *string newID *string
newName *string newName *string
newPrimary *bool
newDomains []string
newNSList []nbdns.NameServer newNSList []nbdns.NameServer
newGroups []string newGroups []string
skipCopying bool skipCopying bool
@@ -356,16 +471,20 @@ func TestSaveNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup expectedNSGroup *nbdns.NameServerGroup
}{ }{
{ {
name: "Should Update Name Server Group", name: "Should Config Name Server Group",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &validName, newName: &validName,
newGroups: validGroups, newGroups: validGroups,
newPrimary: &disabledPrimary,
newDomains: validDomains,
newNSList: validNameServerList, newNSList: validNameServerList,
errFunc: require.NoError, errFunc: require.NoError,
shouldCreate: true, shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{ expectedNSGroup: &nbdns.NameServerGroup{
ID: "testingNSGroup", ID: "testingNSGroup",
Name: validName, Name: validName,
Primary: false,
Domains: validDomains,
Description: "super", Description: "super",
NameServers: validNameServerList, NameServers: validNameServerList,
Groups: validGroups, Groups: validGroups,
@@ -373,68 +492,91 @@ func TestSaveNameServerGroup(t *testing.T) {
}, },
}, },
{ {
name: "Should Not Update If Name Is Small", name: "Should Not Config If Name Is Small",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &invalidNameSmall, newName: &invalidNameSmall,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Name Is Large", name: "Should Not Config If Name Is Large",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &invalidNameLarge, newName: &invalidNameLarge,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Name Exists", name: "Should Not Config If Name Exists",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newName: &invalidNameExisting, newName: &invalidNameExisting,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If ID Don't Exist", name: "Should Not Config If ID Don't Exist",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newID: &invalidID, newID: &invalidID,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Nameserver List Is Small", name: "Should Not Config If Nameserver List Is Small",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newNSList: []nbdns.NameServer{}, newNSList: []nbdns.NameServer{},
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Nameserver List Is Large", name: "Should Not Config If Nameserver List Is Large",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newNSList: invalidNameServerListLarge, newNSList: invalidNameServerListLarge,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Groups List Is Empty", name: "Should Not Config If Groups List Is Empty",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newGroups: []string{}, newGroups: []string{},
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Groups List Has Empty ID", name: "Should Not Config If Groups List Has Empty ID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newGroups: []string{""}, newGroups: []string{""},
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{ {
name: "Should Not Update If Groups List Has Non Existing Group ID", name: "Should Not Config If Groups List Has Non Existing Group ID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
newGroups: invalidGroups, newGroups: invalidGroups,
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{
name: "Should Not Config If Domains List Is Empty",
existingNSGroup: existingNSGroup,
newPrimary: &disabledPrimary,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Config If Primary And Domains",
existingNSGroup: existingNSGroup,
newPrimary: &existingNSGroup.Primary,
newDomains: validDomains,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Config If Domains List Is Invalid",
existingNSGroup: existingNSGroup,
newPrimary: &disabledPrimary,
newDomains: invalidDomains,
errFunc: require.Error,
shouldCreate: false,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
@@ -475,6 +617,14 @@ func TestSaveNameServerGroup(t *testing.T) {
if testCase.newNSList != nil { if testCase.newNSList != nil {
nsGroupToSave.NameServers = testCase.newNSList nsGroupToSave.NameServers = testCase.newNSList
} }
if testCase.newPrimary != nil {
nsGroupToSave.Primary = *testCase.newPrimary
}
if testCase.newDomains != nil {
nsGroupToSave.Domains = testCase.newDomains
}
} }
err = am.SaveNameServerGroup(account.Id, nsGroupToSave) err = am.SaveNameServerGroup(account.Id, nsGroupToSave)
@@ -485,6 +635,11 @@ func TestSaveNameServerGroup(t *testing.T) {
return return
} }
account, err = am.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID] savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved) require.True(t, saved)
@@ -503,6 +658,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID, ID: nsGroupID,
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -529,7 +685,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
expectedNSGroup *nbdns.NameServerGroup expectedNSGroup *nbdns.NameServerGroup
}{ }{
{ {
name: "Should Update Single Property", name: "Should Config Single Property",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -544,6 +700,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID, ID: nsGroupID,
Name: "superNew", Name: "superNew",
Description: "super", Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -561,7 +718,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
}, },
}, },
{ {
name: "Should Update Multiple Properties", name: "Should Config Multiple Properties",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -585,6 +742,14 @@ func TestUpdateNameServerGroup(t *testing.T) {
Type: UpdateNameServerGroupEnabled, Type: UpdateNameServerGroupEnabled,
Values: []string{"false"}, Values: []string{"false"},
}, },
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"false"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{validDomain},
},
}, },
errFunc: require.NoError, errFunc: require.NoError,
shouldCreate: true, shouldCreate: true,
@@ -592,6 +757,8 @@ func TestUpdateNameServerGroup(t *testing.T) {
ID: nsGroupID, ID: nsGroupID,
Name: "superNew", Name: "superNew",
Description: "superDescription", Description: "superDescription",
Primary: false,
Domains: []string{validDomain},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("127.0.0.1"), IP: netip.MustParseAddr("127.0.0.1"),
@@ -609,20 +776,20 @@ func TestUpdateNameServerGroup(t *testing.T) {
}, },
}, },
{ {
name: "Should Not Update On Invalid ID", name: "Should Not Config On Invalid ID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: "nonExistingNSGroup", nsGroupID: "nonExistingNSGroup",
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Empty Operations", name: "Should Not Config On Empty Operations",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{}, operations: []NameServerGroupUpdateOperation{},
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Empty Values", name: "Should Not Config On Empty Values",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -633,7 +800,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Empty String", name: "Should Not Config On Empty String",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -645,7 +812,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Name Large String", name: "Should Not Config On Invalid Name Large String",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -657,7 +824,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid On Existing Name", name: "Should Not Config On Invalid On Existing Name",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -669,7 +836,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid On Multiple Name Values", name: "Should Not Config On Invalid On Multiple Name Values",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -681,7 +848,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Boolean", name: "Should Not Config On Invalid Boolean",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -693,7 +860,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Nameservers Wrong Schema", name: "Should Not Config On Invalid Nameservers Wrong Schema",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -705,7 +872,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid Nameservers Wrong IP", name: "Should Not Config On Invalid Nameservers Wrong IP",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -717,7 +884,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Large Number Of Nameservers", name: "Should Not Config On Large Number Of Nameservers",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -729,7 +896,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Should Not Update On Invalid GroupID", name: "Should Not Config On Invalid GroupID",
existingNSGroup: existingNSGroup, existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID, nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{ operations: []NameServerGroupUpdateOperation{
@@ -740,6 +907,30 @@ func TestUpdateNameServerGroup(t *testing.T) {
}, },
errFunc: require.Error, errFunc: require.Error,
}, },
{
name: "Should Not Config On Invalid Domains",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDomains,
Values: []string{invalidDomain},
},
},
errFunc: require.Error,
},
{
name: "Should Not Config On Invalid Primary Status",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupPrimary,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
@@ -865,12 +1056,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return BuildManager(store, NewPeersUpdateManager(), nil) return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
} }
func createNSStore(t *testing.T) (Store, error) { func createNSStore(t *testing.T) (Store, error) {
dataDir := t.TempDir() dataDir := t.TempDir()
store, err := NewStore(dataDir) store, err := NewFileStore(dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -2,6 +2,7 @@ package server
import ( import (
"github.com/c-robinson/iplib" "github.com/c-robinson/iplib"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -23,9 +24,10 @@ const (
) )
type NetworkMap struct { type NetworkMap struct {
Peers []*Peer Peers []*Peer
Network *Network Network *Network
Routes []*route.Route Routes []*route.Route
DNSConfig nbdns.Config
} }
type Network struct { type Network struct {

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
nbdns "github.com/netbirdio/netbird/dns"
"net" "net"
"strings" "strings"
"time" "time"
@@ -43,7 +44,11 @@ type Peer struct {
// Meta is a Peer system meta data // Meta is a Peer system meta data
Meta PeerSystemMeta Meta PeerSystemMeta
// Name is peer's name (machine name) // Name is peer's name (machine name)
Name string Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus Status *PeerStatus
// The user ID that registered the peer // The user ID that registered the peer
UserID string UserID string
@@ -65,41 +70,84 @@ func (p *Peer) Copy() *Peer {
UserID: p.UserID, UserID: p.UserID,
SSHKey: p.SSHKey, SSHKey: p.SSHKey,
SSHEnabled: p.SSHEnabled, SSHEnabled: p.SSHEnabled,
DNSLabel: p.DNSLabel,
} }
} }
// GetPeer returns a peer from a Store // Copy PeerStatus
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) { func (p *PeerStatus) Copy() *PeerStatus {
am.mux.Lock() return &PeerStatus{
defer am.mux.Unlock() LastSeen: p.LastSeen,
Connected: p.Connected,
}
}
peer, err := am.Store.GetPeer(peerKey) // GetPeer looks up peer by its public WireGuard key
func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) {
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return peer, nil return account.FindPeerByPubKey(peerPubKey)
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
peers := make([]*Peer, 0, len(account.Peers))
for _, peer := range account.Peers {
if !user.IsAdmin() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin
continue
}
peers = append(peers, peer.Copy())
}
return peers, nil
} }
// MarkPeerConnected marks peer as connected (true) or disconnected (false) // MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error { func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool) error {
am.mux.Lock()
defer am.mux.Unlock()
peer, err := am.Store.GetPeer(peerKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
} }
account, err := am.Store.GetPeerAccount(peerKey) unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil { if err != nil {
return err return err
} }
peerCopy := peer.Copy() peer, err := account.FindPeerByPubKey(peerPubKey)
peerCopy.Status.LastSeen = time.Now() if err != nil {
peerCopy.Status.Connected = connected return err
err = am.Store.SavePeer(account.Id, peerCopy) }
newStatus := peer.Status.Copy()
newStatus.LastSeen = time.Now()
newStatus.Connected = connected
peer.Status = newStatus
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peerPubKey, *newStatus)
if err != nil { if err != nil {
return err return err
} }
@@ -108,26 +156,38 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected boo
// UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated. // UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) { func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
peer, err := am.Store.GetPeer(update.Key) //TODO Peer.ID migration: we will need to replace search by ID here
peer, err := account.FindPeerByPubKey(update.Key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peerCopy := peer.Copy()
if peer.Name != "" { if peer.Name != "" {
peerCopy.Name = update.Name peer.Name = update.Name
} }
peerCopy.SSHEnabled = update.SSHEnabled peer.SSHEnabled = update.SSHEnabled
err = am.Store.SavePeer(accountID, peerCopy) existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil {
return nil, err
}
peer.DNSLabel = newLabel
account.UpdatePeer(peer)
err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -137,66 +197,33 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
return nil, err return nil, err
} }
return peerCopy, nil return peer, nil
} }
// RenamePeer changes peer's name // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) RenamePeer( func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) (*Peer, error) {
accountId string,
peerKey string,
newName string,
) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
peer, err := am.Store.GetPeer(peerKey) unlock := am.Store.AcquireAccountLock(accountID)
if err != nil { defer unlock()
return nil, err
}
peerCopy := peer.Copy() account, err := am.Store.GetAccount(accountID)
peerCopy.Name = newName
err = am.Store.SavePeer(accountId, peerCopy)
if err != nil {
return nil, err
}
return peerCopy, nil
}
// DeletePeer removes peer from the account by it's IP
func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
// delete peer from groups peer, err := account.FindPeerByPubKey(peerPubKey)
for _, g := range account.Groups {
for i, pk := range g.Peers {
if pk == peerKey {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
break
}
}
}
peer, err := am.Store.DeletePeer(accountId, peerKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account.Network.IncSerial() account.DeletePeer(peerPubKey)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = am.peersUpdateManager.SendUpdate(peerKey, err = am.peersUpdateManager.SendUpdate(peerPubKey,
&UpdateMessage{ &UpdateMessage{
Update: &proto.SyncResponse{ Update: &proto.SyncResponse{
// fill those field for backward compatibility // fill those field for backward compatibility
@@ -214,20 +241,22 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
return nil, err return nil, err
} }
// TODO Peer.ID migration: we will need to replace search by Peer.ID here
if err := am.updateAccountPeers(account); err != nil { if err := am.updateAccountPeers(account); err != nil {
return nil, err return nil, err
} }
am.peersUpdateManager.CloseChannel(peerKey) am.peersUpdateManager.CloseChannel(peerPubKey)
return peer, nil return peer, nil
} }
// GetPeerByIP returns peer by it's IP // GetPeerByIP returns peer by its IP
func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*Peer, error) { func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
@@ -242,33 +271,42 @@ func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*
} }
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, error) { func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetPeerAccount(peerKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey) return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey)
} }
aclPeers := am.getPeersByACL(account, peerKey) aclPeers := am.getPeersByACL(account, peerPubKey)
routesUpdate := am.getPeersRoutes(append(aclPeers, account.Peers[peerKey])) routesUpdate := account.GetPeersRoutes(append(aclPeers, account.Peers[peerPubKey]))
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(account, am.dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate := nbdns.Config{
ServiceEnable: true,
CustomZones: zones,
NameServerGroups: getPeerNSGroups(account, peerPubKey),
}
return &NetworkMap{ return &NetworkMap{
Peers: aclPeers, Peers: aclPeers,
Network: account.Network.Copy(), Network: account.Network.Copy(),
Routes: routesUpdate, Routes: routesUpdate,
DNSConfig: dnsUpdate,
}, err }, err
} }
// GetPeerNetwork returns the Network for a given peer // GetPeerNetwork returns the Network for a given peer
func (am *DefaultAccountManager) GetPeerNetwork(peerKey string) (*Network, error) { func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetPeerAccount(peerKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey) return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey)
} }
return account.Network.Copy(), err return account.Network.Copy(), err
@@ -281,13 +319,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerKey string) (*Network, error
// to it. We also add the User ID to the peer metadata to identify registrant. // to it. We also add the User ID to the peer metadata to identify registrant.
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further // The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer( func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
setupKey string,
userID string,
peer *Peer,
) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
upperKey := strings.ToUpper(setupKey) upperKey := strings.ToUpper(setupKey)
@@ -326,7 +358,7 @@ func (am *DefaultAccountManager) AddPeer(
groupsToAdd = sk.AutoGroups groupsToAdd = sk.AutoGroups
} else if len(userID) != 0 { } else if len(userID) != 0 {
account, err = am.Store.GetUserAccount(userID) account, err = am.Store.GetAccountByUser(userID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID) return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
} }
@@ -342,11 +374,31 @@ func (am *DefaultAccountManager) AddPeer(
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided") return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
} }
var takenIps []net.IP unlock := am.Store.AcquireAccountLock(account.Id)
for _, peer := range account.Peers { defer unlock()
takenIps = append(takenIps, peer.IP)
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, err
} }
var takenIps []net.IP
existingLabels := make(lookupMap)
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
if existingPeer.DNSLabel != "" {
existingLabels[existingPeer.DNSLabel] = struct{}{}
}
}
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil {
return nil, err
}
peer.DNSLabel = newLabel
network := account.Network network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps) nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil { if err != nil {
@@ -359,6 +411,7 @@ func (am *DefaultAccountManager) AddPeer(
IP: nextIp, IP: nextIp,
Meta: peer.Meta, Meta: peer.Meta,
Name: peer.Name, Name: peer.Name,
DNSLabel: newLabel,
UserID: userID, UserID: userID,
Status: &PeerStatus{Connected: false, LastSeen: time.Now()}, Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false, SSHEnabled: false,
@@ -395,34 +448,41 @@ func (am *DefaultAccountManager) AddPeer(
} }
// UpdatePeerSSHKey updates peer's public SSH key // UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string) error { func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
if sshKey == "" { if sshKey == "" {
log.Debugf("empty SSH key provided for peer %s, skipping update", peerKey) log.Debugf("empty SSH key provided for peer %s, skipping update", peerPubKey)
return nil return nil
} }
peer, err := am.Store.GetPeer(peerKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
} }
if peer.SSHKey == sshKey { if peer.SSHKey == sshKey {
log.Debugf("same SSH key provided for peer %s, skipping update", peerKey) log.Debugf("same SSH key provided for peer %s, skipping update", peerPubKey)
return nil return nil
} }
account, err := am.Store.GetPeerAccount(peerKey) peer.SSHKey = sshKey
if err != nil { account.UpdatePeer(peer)
return err
}
peerCopy := peer.Copy() err = am.Store.SaveAccount(account)
peerCopy.SSHKey = sshKey
err = am.Store.SavePeer(account.Id, peerCopy)
if err != nil { if err != nil {
return err return err
} }
@@ -432,29 +492,30 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string)
} }
// UpdatePeerMeta updates peer's system metadata // UpdatePeerMeta updates peer's system metadata
func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error { func (am *DefaultAccountManager) UpdatePeerMeta(peerPubKey string, meta PeerSystemMeta) error {
am.mux.Lock()
defer am.mux.Unlock()
peer, err := am.Store.GetPeer(peerKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
} }
account, err := am.Store.GetPeerAccount(peerKey) unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
} }
peerCopy := peer.Copy()
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client // Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" { if meta.UIVersion == "" {
meta.UIVersion = peerCopy.Meta.UIVersion meta.UIVersion = peer.Meta.UIVersion
} }
peerCopy.Meta = meta peer.Meta = meta
account.UpdatePeer(peer)
err = am.Store.SavePeer(account.Id, peerCopy) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return err return err
} }
@@ -462,17 +523,9 @@ func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemM
} }
// getPeersByACL returns all peers that given peer has access to. // getPeersByACL returns all peers that given peer has access to.
func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string) []*Peer { func (am *DefaultAccountManager) getPeersByACL(account *Account, peerPubKey string) []*Peer {
var peers []*Peer var peers []*Peer
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey) srcRules, dstRules := account.GetPeerRules(peerPubKey)
if err != nil {
srcRules = []*Rule{}
}
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
if err != nil {
dstRules = []*Rule{}
}
groups := map[string]*Group{} groups := map[string]*Group{}
for _, r := range srcRules { for _, r := range srcRules {
@@ -515,7 +568,7 @@ func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string)
continue continue
} }
// exclude original peer // exclude original peer
if _, ok := peersSet[peer.Key]; peer.Key != peerKey && !ok { if _, ok := peersSet[peer.Key]; peer.Key != peerPubKey && !ok {
peersSet[peer.Key] = struct{}{} peersSet[peer.Key] = struct{}{}
peers = append(peers, peer.Copy()) peers = append(peers, peer.Copy())
} }
@@ -528,36 +581,18 @@ func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string)
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
// notify other peers of the change peers := account.GetPeers()
peers, err := am.Store.GetAccountPeers(account.Id)
if err != nil {
return err
}
network := account.Network.Copy()
for _, peer := range peers { for _, peer := range peers {
aclPeers := am.getPeersByACL(account, peer.Key) remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key)
peersUpdate := toRemotePeerConfig(aclPeers)
routesUpdate := toProtocolRoutes(am.getPeersRoutes(append(aclPeers, peer)))
err = am.peersUpdateManager.SendUpdate(peer.Key,
&UpdateMessage{
Update: &proto.SyncResponse{
// fill deprecated fields for backward compatibility
RemotePeers: peersUpdate,
RemotePeersIsEmpty: len(peersUpdate) == 0,
// new field
NetworkMap: &proto.NetworkMap{
Serial: account.Network.CurrentSerial(),
RemotePeers: peersUpdate,
RemotePeersIsEmpty: len(peersUpdate) == 0,
PeerConfig: toPeerConfig(peer, network),
Routes: routesUpdate,
},
},
})
if err != nil { if err != nil {
return err return status.Errorf(codes.Internal, "unable to fetch network map for peer %s, error: %v", peer.Key, err)
}
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap)
err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update})
if err != nil {
return status.Errorf(codes.Internal, "unable to send update for peer %s, error: %v", peer.Key, err)
} }
} }

View File

@@ -134,7 +134,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
return return
} }
rules, err := manager.ListRules(account.Id) rules, err := manager.ListRules(account.Id, userId)
if err != nil { if err != nil {
t.Errorf("expecting to get a list of rules, got failure %v", err) t.Errorf("expecting to get a list of rules, got failure %v", err)
return return

View File

@@ -60,15 +60,24 @@ type RouteUpdateOperation struct {
} }
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
wantedRoute, found := account.Routes[routeID] wantedRoute, found := account.Routes[routeID]
if found { if found {
return wantedRoute, nil return wantedRoute, nil
@@ -84,7 +93,12 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
return nil return nil
} }
routesWithPrefix, err := am.Store.GetRoutesByPrefix(accountID, prefix) account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
routesWithPrefix := account.GetRoutesByPrefix(prefix)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
@@ -102,8 +116,8 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -166,8 +180,8 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
// SaveRoute saves route // SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.Route) error { func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.Route) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(codes.InvalidArgument, "route provided is nil") return status.Errorf(codes.InvalidArgument, "route provided is nil")
@@ -209,8 +223,8 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
// UpdateRoute updates existing route with set of operations // UpdateRoute updates existing route with set of operations
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) { func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -306,8 +320,8 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error { func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -325,15 +339,24 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
} }
// ListRoutes returns a list of routes from account // ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID string) ([]*route.Route, error) { func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
routes := make([]*route.Route, 0, len(account.Routes)) routes := make([]*route.Route, 0, len(account.Routes))
for _, item := range account.Routes { for _, item := range account.Routes {
routes = append(routes, item) routes = append(routes, item)
@@ -354,30 +377,6 @@ func toProtocolRoute(route *route.Route) *proto.Route {
} }
} }
func (am *DefaultAccountManager) getPeersRoutes(peers []*Peer) []*route.Route {
routes := make([]*route.Route, 0)
for _, peer := range peers {
peerRoutes, err := am.Store.GetPeerRoutes(peer.Key)
if err != nil {
errorStatus, ok := status.FromError(err)
if !ok && errorStatus.Code() != codes.NotFound {
log.Error(err)
}
continue
}
activeRoutes := make([]*route.Route, 0)
for _, pr := range peerRoutes {
if pr.Enabled {
activeRoutes = append(activeRoutes, pr)
}
}
if len(activeRoutes) > 0 {
routes = append(routes, activeRoutes...)
}
}
return routes
}
func toProtocolRoutes(routes []*route.Route) []*proto.Route { func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0) protoRoutes := make([]*proto.Route, 0)
for _, r := range routes { for _, r := range routes {

View File

@@ -380,6 +380,11 @@ func TestSaveRoute(t *testing.T) {
return return
} }
account, err = am.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
}
savedRoute, saved := account.Routes[testCase.expectedRoute.ID] savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
require.True(t, saved) require.True(t, saved)
@@ -740,7 +745,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
err = am.SaveGroup(account.Id, newGroup) err = am.SaveGroup(account.Id, newGroup)
require.NoError(t, err) require.NoError(t, err)
rules, err := am.ListRules(account.Id) rules, err := am.ListRules(account.Id, "testingUser")
require.NoError(t, err) require.NoError(t, err)
defaultRule := rules[0] defaultRule := rules[0]
@@ -778,12 +783,12 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return BuildManager(store, NewPeersUpdateManager(), nil) return BuildManager(store, NewPeersUpdateManager(), nil, "", "")
} }
func createRouterStore(t *testing.T) (Store, error) { func createRouterStore(t *testing.T) (Store, error) {
dataDir := t.TempDir() dataDir := t.TempDir()
store, err := NewStore(dataDir) store, err := NewFileStore(dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -840,5 +845,5 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
return account, nil return am.Store.GetAccount(account.Id)
} }

View File

@@ -89,15 +89,24 @@ func (r *Rule) Copy() *Rule {
} }
// GetRule of ACL from the store // GetRule of ACL from the store
func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) { func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "only admins are allowed to view rules")
}
rule, ok := account.Rules[ruleID] rule, ok := account.Rules[ruleID]
if ok { if ok {
return rule, nil return rule, nil
@@ -108,8 +117,8 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error
// SaveRule of ACL in the store // SaveRule of ACL in the store
func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error { func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -129,8 +138,8 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
// UpdateRule updates a rule using a list of operations // UpdateRule updates a rule using a list of operations
func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
operations []RuleUpdateOperation) (*Rule, error) { operations []RuleUpdateOperation) (*Rule, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -203,8 +212,8 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
// DeleteRule of ACL from the store // DeleteRule of ACL from the store
func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -222,15 +231,24 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
} }
// ListRules of ACL from the store // ListRules of ACL from the store
func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) { func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules")
}
rules := make([]*Rule, 0, len(account.Rules)) rules := make([]*Rule, 0, len(account.Rules))
for _, item := range account.Rules { for _, item := range account.Rules {
rules = append(rules, item) rules = append(rules, item)

View File

@@ -9,6 +9,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"unicode/utf8"
) )
const ( const (
@@ -100,6 +101,15 @@ func (key *SetupKey) Copy() *SetupKey {
} }
} }
// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
func (key *SetupKey) HiddenCopy() *SetupKey {
k := key.Copy()
prefix := k.Key[0:5]
k.Key = prefix + strings.Repeat("*", utf8.RuneCountInString(key.Key)-len(prefix))
return k
}
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now // IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey { func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy() c := key.Copy()
@@ -163,8 +173,8 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty. // and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) { expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
keyDuration := DefaultSetupKeyDuration keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 { if expiresIn != 0 {
@@ -198,8 +208,8 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
// (e.g. the key itself, creation date, ID, etc). // (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) { func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if keyToSave == nil { if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil") return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
@@ -238,32 +248,48 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
} }
// ListSetupKeys returns a list of all setup keys of the account // ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) { func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
keys := make([]*SetupKey, 0, len(account.SetupKeys)) keys := make([]*SetupKey, 0, len(account.SetupKeys))
for _, key := range account.SetupKeys { for _, key := range account.SetupKeys {
keys = append(keys, key.Copy()) var k *SetupKey
if !user.IsAdmin() {
k = key.HiddenCopy()
} else {
k = key.Copy()
}
keys = append(keys, k)
} }
return keys, nil return keys, nil
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) { func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
} }
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
var foundKey *SetupKey var foundKey *SetupKey
for _, key := range account.SetupKeys { for _, key := range account.SetupKeys {
if key.Id == keyID { if key.Id == keyID {
@@ -280,5 +306,9 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey
foundKey.UpdatedAt = foundKey.CreatedAt foundKey.UpdatedAt = foundKey.CreatedAt
} }
if !user.IsAdmin() {
foundKey = foundKey.HiddenCopy()
}
return foundKey, nil return foundKey, nil
} }

View File

@@ -1,26 +1,20 @@
package server package server
import (
"github.com/netbirdio/netbird/route"
"net/netip"
)
type Store interface { type Store interface {
GetPeer(peerKey string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error)
SavePeer(accountId string, peer *Peer) error
GetAllAccounts() []*Account GetAllAccounts() []*Account
GetAccount(accountId string) (*Account, error) GetAccount(accountID string) (*Account, error)
GetUserAccount(userId string) (*Account, error) GetAccountByUser(userID string) (*Account, error)
GetAccountPeers(accountId string) ([]*Peer, error) GetAccountByPeerPubKey(peerKey string) (*Account, error)
GetPeerAccount(peerKey string) (*Account, error) GetAccountBySetupKey(setupKey string) (*Account, error) //todo use key hash later
GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error)
GetPeerDstRules(accountId, peerKey string) ([]*Rule, error)
GetAccountBySetupKey(setupKey string) (*Account, error)
GetAccountByPrivateDomain(domain string) (*Account, error) GetAccountByPrivateDomain(domain string) (*Account, error)
SaveAccount(account *Account) error SaveAccount(account *Account) error
GetPeerRoutes(peerKey string) ([]*route.Route, error)
GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error)
GetInstallationID() string GetInstallationID() string
SaveInstallationID(id string) error SaveInstallationID(ID string) error
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
AcquireAccountLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func()
SavePeerStatus(accountID, peerKey string, status PeerStatus) error
// Close should close the store persisting all unsaved data.
Close() error
} }

View File

@@ -0,0 +1,181 @@
package telemetry
import (
"context"
"fmt"
"github.com/gorilla/mux"
prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/exporters/prometheus"
metric2 "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/metric"
"net"
"net/http"
"reflect"
)
const defaultEndpoint = "/metrics"
// MockAppMetrics mocks the AppMetrics interface
type MockAppMetrics struct {
GetMeterFunc func() metric2.Meter
CloseFunc func() error
ExposeFunc func(port int, endpoint string) error
IDPMetricsFunc func() *IDPMetrics
HTTPMiddlewareFunc func() *HTTPMiddleware
GRPCMetricsFunc func() *GRPCMetrics
}
// GetMeter mocks the GetMeter function of the AppMetrics interface
func (mock *MockAppMetrics) GetMeter() metric2.Meter {
if mock.GetMeterFunc != nil {
return mock.GetMeterFunc()
}
return nil
}
// Close mocks the Close function of the AppMetrics interface
func (mock *MockAppMetrics) Close() error {
if mock.CloseFunc != nil {
return mock.CloseFunc()
}
return fmt.Errorf("unimplemented")
}
// Expose mocks the Expose function of the AppMetrics interface
func (mock *MockAppMetrics) Expose(port int, endpoint string) error {
if mock.ExposeFunc != nil {
return mock.ExposeFunc(port, endpoint)
}
return fmt.Errorf("unimplemented")
}
// IDPMetrics mocks the IDPMetrics function of the IDPMetrics interface
func (mock *MockAppMetrics) IDPMetrics() *IDPMetrics {
if mock.IDPMetricsFunc != nil {
return mock.IDPMetricsFunc()
}
return nil
}
// HTTPMiddleware mocks the HTTPMiddleware function of the IDPMetrics interface
func (mock *MockAppMetrics) HTTPMiddleware() *HTTPMiddleware {
if mock.HTTPMiddlewareFunc != nil {
return mock.HTTPMiddlewareFunc()
}
return nil
}
// GRPCMetrics mocks the GRPCMetrics function of the IDPMetrics interface
func (mock *MockAppMetrics) GRPCMetrics() *GRPCMetrics {
if mock.GRPCMetricsFunc != nil {
return mock.GRPCMetricsFunc()
}
return nil
}
// AppMetrics is metrics interface
type AppMetrics interface {
GetMeter() metric2.Meter
Close() error
Expose(port int, endpoint string) error
IDPMetrics() *IDPMetrics
HTTPMiddleware() *HTTPMiddleware
GRPCMetrics() *GRPCMetrics
}
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
type defaultAppMetrics struct {
// Meter can be used by different application parts to create counters and measure things
Meter metric2.Meter
listener net.Listener
ctx context.Context
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
}
// IDPMetrics returns metrics for the idp package
func (appMetrics *defaultAppMetrics) IDPMetrics() *IDPMetrics {
return appMetrics.idpMetrics
}
// HTTPMiddleware returns metrics for the http api package
func (appMetrics *defaultAppMetrics) HTTPMiddleware() *HTTPMiddleware {
return appMetrics.httpMiddleware
}
// GRPCMetrics returns metrics for the gRPC api
func (appMetrics *defaultAppMetrics) GRPCMetrics() *GRPCMetrics {
return appMetrics.grpcMetrics
}
// Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil {
return nil
}
return appMetrics.listener.Close()
}
// Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used.
// Exposes metrics in the Prometheus format https://prometheus.io/
func (appMetrics *defaultAppMetrics) Expose(port int, endpoint string) error {
if endpoint == "" {
endpoint = defaultEndpoint
}
rootRouter := mux.NewRouter()
rootRouter.Handle(endpoint, promhttp.HandlerFor(
prometheus2.DefaultGatherer,
promhttp.HandlerOpts{EnableOpenMetrics: true}))
listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
if err != nil {
return err
}
appMetrics.listener = listener
go func() {
err := http.Serve(listener, rootRouter)
if err != nil {
return
}
}()
log.Infof("enabled application metrics and exposing on http://%s", listener.Addr().String())
return nil
}
// GetMeter returns metrics meter that can be used to add various counters
func (appMetrics *defaultAppMetrics) GetMeter() metric2.Meter {
return appMetrics.Meter
}
// NewDefaultAppMetrics and expose them via defaultEndpoint on a given HTTP port
func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
exporter, err := prometheus.New()
if err != nil {
return nil, err
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(defaultEndpoint).PkgPath()
meter := provider.Meter(pkg)
idpMetrics, err := NewIDPMetrics(ctx, meter)
if err != nil {
return nil, err
}
middleware, err := NewMetricsMiddleware(ctx, meter)
if err != nil {
return nil, err
}
grpcMetrics, err := NewGRPCMetrics(ctx, meter)
if err != nil {
return nil, err
}
return &defaultAppMetrics{Meter: meter, ctx: ctx, idpMetrics: idpMetrics, httpMiddleware: middleware,
grpcMetrics: grpcMetrics}, nil
}

View File

@@ -0,0 +1,76 @@
package telemetry
import (
"context"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/asyncint64"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// GRPCMetrics are gRPC server metrics
type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter syncint64.Counter
loginRequestsCounter syncint64.Counter
getKeyRequestsCounter syncint64.Counter
activeStreamsGauge asyncint64.Gauge
ctx context.Context
}
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, error) {
syncRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.sync.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.login.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getKeyRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.key.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
activeStreamsGauge, err := meter.AsyncInt64().Gauge("management.grpc.connected.streams", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
return &GRPCMetrics{
meter: meter,
syncRequestsCounter: syncRequestsCounter,
loginRequestsCounter: loginRequestsCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
ctx: ctx,
}, err
}
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountSyncRequest() {
grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() {
grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// CountLoginRequest counts the number of gRPC login requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountLoginRequest() {
grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error {
return grpcMetrics.meter.RegisterCallback(
[]instrument.Asynchronous{
grpcMetrics.activeStreamsGauge,
},
func(ctx context.Context) {
grpcMetrics.activeStreamsGauge.Observe(ctx, producer())
},
)
}

View File

@@ -0,0 +1,176 @@
package telemetry
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
"hash/fnv"
"net/http"
"strings"
)
const (
httpRequestCounterPrefix = "management.http.request.counter"
httpResponseCounterPrefix = "management.http.response.counter"
)
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
// written HTTP status code to be captured for metrics reporting or logging purposes.
type WrappedResponseWriter struct {
http.ResponseWriter
status int
wroteHeader bool
}
// WrapResponseWriter wraps original http.ResponseWriter
func WrapResponseWriter(w http.ResponseWriter) *WrappedResponseWriter {
return &WrappedResponseWriter{ResponseWriter: w}
}
// Status returns response status
func (rw *WrappedResponseWriter) Status() int {
return rw.status
}
// WriteHeader wraps http.ResponseWriter.WriteHeader method
func (rw *WrappedResponseWriter) WriteHeader(code int) {
if rw.wroteHeader {
return
}
rw.status = code
rw.ResponseWriter.WriteHeader(code)
rw.wroteHeader = true
}
// HTTPMiddleware handler used to collect metrics of every request/response coming to the API.
// Also adds request tracing (logging).
type HTTPMiddleware struct {
meter metric.Meter
ctx context.Context
// defaultEndpoint & method
httpRequestCounters map[string]syncint64.Counter
// defaultEndpoint & method & status code
httpResponseCounters map[string]syncint64.Counter
// all HTTP requests
totalHTTPRequestsCounter syncint64.Counter
// all HTTP responses
totalHTTPResponseCounter syncint64.Counter
// all HTTP responses by status code
totalHTTPResponseCodeCounters map[int]syncint64.Counter
}
// AddHTTPRequestResponseCounter adds a new meter for an HTTP defaultEndpoint and Method (GET, POST, etc)
// Creates one request counter and multiple response counters (one per http response status code).
func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method string) error {
meterKey := getRequestCounterKey(endpoint, method)
httpReqCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
if err != nil {
return err
}
m.httpRequestCounters[meterKey] = httpReqCounter
respCodes := []int{200, 204, 400, 401, 403, 404, 500, 502, 503}
for _, code := range respCodes {
meterKey = getResponseCounterKey(endpoint, method, code)
httpRespCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
if err != nil {
return err
}
m.httpResponseCounters[meterKey] = httpRespCounter
meterKey = fmt.Sprintf("%s_%d_total", httpResponseCounterPrefix, code)
totalHTTPResponseCodeCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
if err != nil {
return err
}
m.totalHTTPResponseCodeCounters[code] = totalHTTPResponseCodeCounter
}
return nil
}
// NewMetricsMiddleware creates a new HTTPMiddleware
func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddleware, error) {
totalHTTPRequestsCounter, err := meter.SyncInt64().Counter(
fmt.Sprintf("%s_total", httpRequestCounterPrefix),
instrument.WithUnit("1"))
if err != nil {
return nil, err
}
totalHTTPResponseCounter, err := meter.SyncInt64().Counter(
fmt.Sprintf("%s_total", httpResponseCounterPrefix),
instrument.WithUnit("1"))
if err != nil {
return nil, err
}
return &HTTPMiddleware{
ctx: ctx,
httpRequestCounters: map[string]syncint64.Counter{},
httpResponseCounters: map[string]syncint64.Counter{},
totalHTTPResponseCodeCounters: map[int]syncint64.Counter{},
meter: meter,
totalHTTPRequestsCounter: totalHTTPRequestsCounter,
totalHTTPResponseCounter: totalHTTPResponseCounter,
},
nil
}
func getRequestCounterKey(endpoint, method string) string {
return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix,
strings.ReplaceAll(endpoint, "/", "_"), method)
}
func getResponseCounterKey(endpoint, method string, status int) string {
return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix,
strings.ReplaceAll(endpoint, "/", "_"), method, status)
}
// Handler logs every request and response and adds the, to metrics.
func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
fn := func(rw http.ResponseWriter, r *http.Request) {
traceID := hash(fmt.Sprintf("%v", r))
log.Tracef("HTTP request %v: %v %v", traceID, r.Method, r.URL)
metricKey := getRequestCounterKey(r.URL.Path, r.Method)
if c, ok := m.httpRequestCounters[metricKey]; ok {
c.Add(m.ctx, 1)
}
m.totalHTTPRequestsCounter.Add(m.ctx, 1)
w := WrapResponseWriter(rw)
h.ServeHTTP(w, r)
if w.Status() > 399 {
log.Errorf("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status())
} else {
log.Tracef("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status())
}
metricKey = getResponseCounterKey(r.URL.Path, r.Method, w.Status())
if c, ok := m.httpResponseCounters[metricKey]; ok {
c.Add(m.ctx, 1)
}
m.totalHTTPResponseCounter.Add(m.ctx, 1)
if c, ok := m.totalHTTPResponseCodeCounters[w.Status()]; ok {
c.Add(m.ctx, 1)
}
}
return http.HandlerFunc(fn)
}
func hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}

View File

@@ -0,0 +1,119 @@
package telemetry
import (
"context"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// IDPMetrics is common IdP metrics
type IDPMetrics struct {
metaUpdateCounter syncint64.Counter
getUserByEmailCounter syncint64.Counter
getAllAccountsCounter syncint64.Counter
createUserCounter syncint64.Counter
getAccountCounter syncint64.Counter
getUserByIDCounter syncint64.Counter
authenticateRequestCounter syncint64.Counter
requestErrorCounter syncint64.Counter
requestStatusErrorCounter syncint64.Counter
ctx context.Context
}
// NewIDPMetrics creates new IDPMetrics struct and registers common
func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) {
metaUpdateCounter, err := meter.SyncInt64().Counter("management.idp.update.user.meta.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getUserByEmailCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.email.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getAllAccountsCounter, err := meter.SyncInt64().Counter("management.idp.get.accounts.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
createUserCounter, err := meter.SyncInt64().Counter("management.idp.create.user.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getUserByIDCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.id.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
authenticateRequestCounter, err := meter.SyncInt64().Counter("management.idp.authenticate.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
requestErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.error.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
requestStatusErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.status.error.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
return &IDPMetrics{
metaUpdateCounter: metaUpdateCounter,
getUserByEmailCounter: getUserByEmailCounter,
getAllAccountsCounter: getAllAccountsCounter,
createUserCounter: createUserCounter,
getAccountCounter: getAccountCounter,
getUserByIDCounter: getUserByIDCounter,
authenticateRequestCounter: authenticateRequestCounter,
requestErrorCounter: requestErrorCounter,
requestStatusErrorCounter: requestStatusErrorCounter,
ctx: ctx}, nil
}
// CountUpdateUserAppMetadata ...
func (idpMetrics *IDPMetrics) CountUpdateUserAppMetadata() {
idpMetrics.metaUpdateCounter.Add(idpMetrics.ctx, 1)
}
// CountGetUserByEmail ...
func (idpMetrics *IDPMetrics) CountGetUserByEmail() {
idpMetrics.getUserByEmailCounter.Add(idpMetrics.ctx, 1)
}
// CountCreateUser ...
func (idpMetrics *IDPMetrics) CountCreateUser() {
idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1)
}
// CountGetAllAccounts ...
func (idpMetrics *IDPMetrics) CountGetAllAccounts() {
idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1)
}
// CountGetAccount ...
func (idpMetrics *IDPMetrics) CountGetAccount() {
idpMetrics.getAccountCounter.Add(idpMetrics.ctx, 1)
}
// CountGetUserDataByID ...
func (idpMetrics *IDPMetrics) CountGetUserDataByID() {
idpMetrics.getUserByIDCounter.Add(idpMetrics.ctx, 1)
}
// CountAuthenticate ...
func (idpMetrics *IDPMetrics) CountAuthenticate() {
idpMetrics.authenticateRequestCounter.Add(idpMetrics.ctx, 1)
}
// CountRequestError counts number of error that happened when doing http request (httpClient.Do)
func (idpMetrics *IDPMetrics) CountRequestError() {
idpMetrics.requestErrorCounter.Add(idpMetrics.ctx, 1)
}
// CountRequestStatusError counts number of responses that came from IdP with non success status code
func (idpMetrics *IDPMetrics) CountRequestStatusError() {
idpMetrics.requestStatusErrorCounter.Add(idpMetrics.ctx, 1)
}

View File

@@ -46,6 +46,11 @@ type User struct {
AutoGroups []string AutoGroups []string
} }
// IsAdmin returns true if user is an admin, false otherwise
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
// toUserInfo converts a User object to a UserInfo object. // toUserInfo converts a User object to a UserInfo object.
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups autoGroups := u.AutoGroups
@@ -68,7 +73,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
} }
userStatus := UserStatusActive userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite { if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited userStatus = UserStatusInvited
} }
@@ -114,8 +119,8 @@ func NewAdminUser(id string) *User {
// CreateUser creates a new user under the given account. Effectively this is a user invite. // CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) { func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if am.idpManager == nil { if am.idpManager == nil {
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites") return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
@@ -179,8 +184,8 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now. // Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) { func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if update == nil { if update == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil") return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
@@ -229,16 +234,16 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
} }
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) { func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) {
am.mux.Lock() unlock := am.Store.AcquireGlobalLock()
defer am.mux.Unlock() defer unlock()
lowerDomain := strings.ToLower(domain) lowerDomain := strings.ToLower(domain)
account, err := am.Store.GetUserAccount(userId) account, err := am.Store.GetAccountByUser(userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account, err = am.newAccount(userId, lowerDomain) account, err = am.newAccount(userID, lowerDomain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -252,7 +257,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
} }
} }
userObj := account.Users[userId] userObj := account.Users[userID]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain account.Domain = lowerDomain
@@ -265,14 +270,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
return account, nil return account, nil
} }
// GetAccountByUser returns an existing account for a given user id, NotFound if account couldn't be found
func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
return am.Store.GetUserAccount(userId)
}
// IsUserAdmin flag for current user authenticated by JWT token // IsUserAdmin flag for current user authenticated by JWT token
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, err := am.GetAccountFromToken(claims) account, err := am.GetAccountFromToken(claims)
@@ -288,9 +285,15 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
return user.Role == UserRoleAdmin, nil return user.Role == UserRoleAdmin, nil
} }
// GetUsersFromAccount performs a batched request for users from IDP by account ID // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) { // based on provided user role.
account, err := am.GetAccountById(accountID) func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -311,8 +314,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 { if len(queriedUsers) == 0 {
for _, user := range account.Users { for _, accountUser := range account.Users {
info, err := user.toUserInfo(nil) if !user.IsAdmin() && user.Id != accountUser.Id {
// if user is not an admin then show only current user and do not show other users
continue
}
info, err := accountUser.toUserInfo(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -322,6 +329,10 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
} }
for _, queriedUser := range queriedUsers { for _, queriedUser := range queriedUsers {
if !user.IsAdmin() && user.Id != queriedUser.ID {
// if user is not an admin then show only current user and do not show other users
continue
}
if localUser, contains := account.Users[queriedUser.ID]; contains { if localUser, contains := account.Users[queriedUser.ID]; contains {
info, err := localUser.toUserInfo(queriedUser) info, err := localUser.toUserInfo(queriedUser)

View File

@@ -3,7 +3,6 @@ package util
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
) )
@@ -24,7 +23,7 @@ func WriteJson(file string, obj interface{}) error {
return err return err
} }
tempFile, err := ioutil.TempFile(configDir, ".*"+configFileName) tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil { if err != nil {
return err return err
} }
@@ -43,7 +42,7 @@ func WriteJson(file string, obj interface{}) error {
} }
}() }()
err = ioutil.WriteFile(tempFileName, bs, 0600) err = os.WriteFile(tempFileName, bs, 0600)
if err != nil { if err != nil {
return err return err
} }
@@ -65,7 +64,7 @@ func ReadJson(file string, res interface{}) (interface{}, error) {
} }
defer f.Close() defer f.Close()
bs, err := ioutil.ReadAll(f) bs, err := io.ReadAll(f)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -40,7 +40,7 @@ func InitLog(logLevel string, logPath string) error {
return "", fileName return "", fileName
} }
if level == log.DebugLevel { if level > log.WarnLevel {
log.SetReportCaller(true) log.SetReportCaller(true)
} }