mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 07:46:38 +00:00
Compare commits
21 Commits
v0.30.3
...
test/callb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5d9e3fb13 | ||
|
|
3d35d6fe09 | ||
|
|
a9d06b883f | ||
|
|
5f06b202c3 | ||
|
|
0eb99c266a | ||
|
|
bac95ace18 | ||
|
|
9812de853b | ||
|
|
ad4f0a6fdf | ||
|
|
4c758c6e52 | ||
|
|
ec5095ba6b | ||
|
|
49a54624f8 | ||
|
|
729bcf2b01 | ||
|
|
a0cdb58303 | ||
|
|
39c99781cb | ||
|
|
01f24907c5 | ||
|
|
10480eb52f | ||
|
|
1e44c5b574 | ||
|
|
940f8b4547 | ||
|
|
46e37fa04c | ||
|
|
b9f205b2ce | ||
|
|
0fd874fa45 |
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [netbirdio]
|
||||
@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -37,62 +38,55 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
fm, errFw := createNativeFirewall(iface)
|
||||
fm, err := createNativeFirewall(iface, stateManager)
|
||||
|
||||
if fm != nil {
|
||||
if err := fm.Init(stateManager); err != nil {
|
||||
log.Errorf("failed to init nftables manager: %s", err)
|
||||
}
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
return createUserspaceFirewall(iface, fm, errFw)
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
|
||||
return fm, errFw
|
||||
return createUserspaceFirewall(iface, fm)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
}
|
||||
|
||||
if err = fm.Init(stateManager); err != nil {
|
||||
return nil, fmt.Errorf("init firewall: %s", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
return createIptablesFirewall(iface)
|
||||
log.Info("creating an iptables firewall manager")
|
||||
return nbiptables.Create(iface)
|
||||
case NFTABLES:
|
||||
return createNftablesFirewall(iface)
|
||||
log.Info("creating an nftables firewall manager")
|
||||
return nbnftables.Create(iface)
|
||||
default:
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, fmt.Errorf("no firewall manager found")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
}
|
||||
}
|
||||
|
||||
func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
||||
log.Info("creating an iptables firewall manager")
|
||||
fm, err := nbiptables.Create(iface)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", err)
|
||||
}
|
||||
return fm, err
|
||||
}
|
||||
|
||||
func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
||||
log.Info("creating an nftables firewall manager")
|
||||
fm, err := nbnftables.Create(iface)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
return fm, err
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if errFw == nil {
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
||||
return nil, errUsp
|
||||
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
|
||||
@@ -296,6 +296,8 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -230,23 +230,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
oldLegacy := m.router.legacyManagement
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
m.router.legacyManagement = isLegacy
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
|
||||
@@ -551,7 +551,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -237,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(_ bool) error {
|
||||
return nil
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
|
||||
@@ -104,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
buf := make([]byte, 1500)
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package acl
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -10,14 +11,18 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||
@@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||
var newRouteRules = make(map[id.RuleID]struct{})
|
||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||
var merr *multierror.Error
|
||||
|
||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply route ACL: %w", err)
|
||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
||||
} else {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
newRouteRules[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Clean up old firewall rules
|
||||
for id := range d.routeRules {
|
||||
if _, ok := newRouteRules[id]; !ok {
|
||||
if _, exists := newRouteRules[id]; !exists {
|
||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||
log.Errorf("failed to delete route firewall rule: %v", err)
|
||||
continue
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||
}
|
||||
delete(d.routeRules, id)
|
||||
// implicitly deleted from the map
|
||||
}
|
||||
}
|
||||
|
||||
d.routeRules = newRouteRules
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||
if len(rule.SourceRanges) == 0 {
|
||||
return "", fmt.Errorf("source ranges is empty")
|
||||
return "", ErrSourceRangesEmpty
|
||||
}
|
||||
|
||||
var sources []netip.Prefix
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -65,14 +64,6 @@ type ConnConfig struct {
|
||||
ICEConfig icemaker.Config
|
||||
}
|
||||
|
||||
type WorkerCallbacks struct {
|
||||
OnRelayReadyCallback func(info RelayConnInfo)
|
||||
OnRelayStatusChanged func(ConnStatus)
|
||||
|
||||
OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
|
||||
OnICEStatusChanged func(ConnStatus)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
log *log.Entry
|
||||
mu sync.Mutex
|
||||
@@ -134,32 +125,16 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
}
|
||||
|
||||
rFns := WorkerRelayCallbacks{
|
||||
OnConnReady: conn.relayConnectionIsReady,
|
||||
OnDisconnected: conn.onWorkerRelayStateDisconnected,
|
||||
}
|
||||
|
||||
wFns := WorkerICECallbacks{
|
||||
OnConnReady: conn.iCEConnectionIsReady,
|
||||
OnStatusChanged: conn.onWorkerICEStateDisconnected,
|
||||
}
|
||||
|
||||
ctrl := isController(config)
|
||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
|
||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
|
||||
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, conn.workerRelay.RelayIsSupportedLocally())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
|
||||
|
||||
go conn.handshaker.Listen()
|
||||
@@ -301,7 +276,7 @@ func (conn *Conn) GetKey() string {
|
||||
}
|
||||
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
||||
func (conn *Conn) iCEConnectionIsReady(iceConnInfo ICEConnInfo) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -309,9 +284,18 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
// temporarily disabled the check
|
||||
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
||||
conn.log.Errorf("remote ICE connection is nil")
|
||||
return
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
if conn.currentConnPriority > priority {
|
||||
if conn.currentConnPriority > iceConnInfo.Priority {
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
return
|
||||
@@ -361,14 +345,14 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
conn.currentConnPriority = priority
|
||||
conn.currentConnPriority = iceConnInfo.Priority
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
// todo review to make sense to handle connecting and disconnected status also?
|
||||
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
func (conn *Conn) onWorkerICEStateDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -376,7 +360,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
||||
conn.log.Tracef("ICE connection state changed to disconnected")
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
@@ -396,8 +380,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||
conn.statusICE.Set(newState)
|
||||
changed := conn.statusICE.Get() != StatusDisconnected
|
||||
conn.statusICE.Set(StatusDisconnected)
|
||||
|
||||
conn.guard.SetICEConnDisconnected(changed)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -58,7 +59,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||
return &Handshaker{
|
||||
hs := &Handshaker{
|
||||
ctx: ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
@@ -68,10 +69,12 @@ func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signa
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
||||
hs.onNewOfferListeners = append(hs.onNewOfferListeners, hs.relay.OnNewOffer)
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
hs.onNewOfferListeners = append(hs.onNewOfferListeners, hs.ice.OnNewOffer)
|
||||
}
|
||||
return hs
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen() {
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
"github.com/pion/randutil"
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -77,10 +78,7 @@ func CandidateTypes() []ice.CandidateType {
|
||||
if hasICEForceRelayConn() {
|
||||
return []ice.CandidateType{ice.CandidateTypeRelay}
|
||||
}
|
||||
// TODO: remove this once we have refactored userspace proxy into the bind package
|
||||
if runtime.GOOS == "ios" {
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
|
||||
}
|
||||
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
|
||||
21
client/internal/peer/nilcheck.go
Normal file
21
client/internal/peer/nilcheck.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func remoteConnNil(log *log.Entry, conn net.Conn) bool {
|
||||
if conn == nil {
|
||||
log.Errorf("ice conn is nil")
|
||||
return true
|
||||
}
|
||||
|
||||
if conn.RemoteAddr() == nil {
|
||||
log.Errorf("ICE remote address is nil")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -29,24 +29,18 @@ type ICEConnInfo struct {
|
||||
LocalIceCandidateEndpoint string
|
||||
Relayed bool
|
||||
RelayedOnLocal bool
|
||||
}
|
||||
|
||||
type WorkerICECallbacks struct {
|
||||
OnConnReady func(ConnPriority, ICEConnInfo)
|
||||
OnStatusChanged func(ConnStatus)
|
||||
Priority ConnPriority
|
||||
}
|
||||
|
||||
type WorkerICE struct {
|
||||
ctx context.Context
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
conn *Conn
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
statusRecorder *Status
|
||||
hasRelayOnLocally bool
|
||||
conn WorkerICECallbacks
|
||||
|
||||
selectedPriority ConnPriority
|
||||
|
||||
agent *ice.Agent
|
||||
muxAgent sync.Mutex
|
||||
@@ -59,16 +53,16 @@ type WorkerICE struct {
|
||||
localPwd string
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||
w := &WorkerICE{
|
||||
ctx: ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
conn: conn,
|
||||
signaler: signaler,
|
||||
iFaceDiscover: ifaceDiscover,
|
||||
statusRecorder: statusRecorder,
|
||||
hasRelayOnLocally: hasRelayOnLocally,
|
||||
conn: callBacks,
|
||||
}
|
||||
|
||||
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
|
||||
@@ -90,12 +84,16 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
return
|
||||
}
|
||||
|
||||
var preferredCandidateTypes []ice.CandidateType
|
||||
var (
|
||||
preferredCandidateTypes []ice.CandidateType
|
||||
selectedPriority ConnPriority
|
||||
)
|
||||
|
||||
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
||||
w.selectedPriority = connPriorityICEP2P
|
||||
selectedPriority = connPriorityICEP2P
|
||||
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
||||
} else {
|
||||
w.selectedPriority = connPriorityICETurn
|
||||
selectedPriority = connPriorityICETurn
|
||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||
}
|
||||
|
||||
@@ -154,9 +152,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
Relayed: isRelayed(pair),
|
||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||
Priority: selectedPriority,
|
||||
}
|
||||
w.log.Debugf("on ICE conn read to use ready")
|
||||
go w.conn.OnConnReady(w.selectedPriority, ci)
|
||||
w.conn.iCEConnectionIsReady(ci)
|
||||
}
|
||||
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
@@ -216,7 +215,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
w.conn.onWorkerICEStateDisconnected()
|
||||
|
||||
w.muxAgent.Lock()
|
||||
agentCancel()
|
||||
|
||||
@@ -34,7 +34,7 @@ type WorkerRelay struct {
|
||||
isController bool
|
||||
config ConnConfig
|
||||
relayManager relayClient.ManagerService
|
||||
callBacks WorkerRelayCallbacks
|
||||
conn *Conn
|
||||
|
||||
relayedConn net.Conn
|
||||
relayLock sync.Mutex
|
||||
@@ -45,13 +45,13 @@ type WorkerRelay struct {
|
||||
relaySupportedOnRemotePeer atomic.Bool
|
||||
}
|
||||
|
||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
|
||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
|
||||
r := &WorkerRelay{
|
||||
log: log,
|
||||
isController: ctrl,
|
||||
config: config,
|
||||
relayManager: relayManager,
|
||||
callBacks: callbacks,
|
||||
conn: conn,
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
}
|
||||
|
||||
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
||||
go w.callBacks.OnConnReady(RelayConnInfo{
|
||||
go w.conn.relayConnectionIsReady(RelayConnInfo{
|
||||
relayedConn: relayedConn,
|
||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||
@@ -187,7 +187,7 @@ func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.Cancel
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayLock.Unlock()
|
||||
w.callBacks.OnDisconnected()
|
||||
go w.conn.onWorkerRelayStateDisconnected()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -232,5 +232,5 @@ func (w *WorkerRelay) onRelayMGDisconnected() {
|
||||
if w.ctxCancelWgWatch != nil {
|
||||
w.ctxCancelWgWatch()
|
||||
}
|
||||
go w.callBacks.OnDisconnected()
|
||||
go w.conn.onWorkerRelayStateDisconnected()
|
||||
}
|
||||
|
||||
7
client/server/panic_generic.go
Normal file
7
client/server/panic_generic.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
func handlePanicLog() error {
|
||||
return nil
|
||||
}
|
||||
83
client/server/panic_windows.go
Normal file
83
client/server/panic_windows.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
|
||||
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
|
||||
stdErrorHandle = ^uintptr(11)
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/console/setstdhandle
|
||||
setStdHandleFn = kernel32.NewProc("SetStdHandle")
|
||||
)
|
||||
|
||||
func handlePanicLog() error {
|
||||
logPath := os.Getenv(windowsPanicLogEnvVar)
|
||||
if logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
logDir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||
return fmt.Errorf("create panic log directory: %w", err)
|
||||
}
|
||||
if err := util.EnforcePermission(logPath); err != nil {
|
||||
return fmt.Errorf("enforce permission on panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Open log file with append mode
|
||||
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Redirect stderr to the file
|
||||
if err = redirectStderr(f); err != nil {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after redirect error: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("redirect stderr: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("successfully configured panic logging to: %s", logPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// redirectStderr redirects stderr to the provided file
|
||||
func redirectStderr(f *os.File) error {
|
||||
// Get the current process's stderr handle
|
||||
if err := setStdHandle(f); err != nil {
|
||||
return fmt.Errorf("failed to set stderr handle: %w", err)
|
||||
}
|
||||
|
||||
// Also set os.Stderr for Go's standard library
|
||||
os.Stderr = f
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setStdHandle(f *os.File) error {
|
||||
handle := f.Fd()
|
||||
r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle)
|
||||
if r0 == 0 {
|
||||
if e1 != nil {
|
||||
return e1
|
||||
}
|
||||
return syscall.EINVAL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -97,6 +97,10 @@ func (s *Server) Start() error {
|
||||
defer s.mutex.Unlock()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
|
||||
if err := handlePanicLog(); err != nil {
|
||||
log.Warnf("failed to redirect stderr: %v", err)
|
||||
}
|
||||
|
||||
if err := restoreResidualState(s.rootCtx); err != nil {
|
||||
log.Warnf(errRestoreResidualState, err)
|
||||
}
|
||||
|
||||
126
funding.json
Normal file
126
funding.json
Normal file
@@ -0,0 +1,126 @@
|
||||
{
|
||||
"version": "v1.0.0",
|
||||
"entity": {
|
||||
"type": "organisation",
|
||||
"role": "owner",
|
||||
"name": "NetBird GmbH",
|
||||
"email": "hello@netbird.io",
|
||||
"phone": "",
|
||||
"description": "NetBird GmbH is a Berlin-based software company specializing in the development of open-source network security solutions. Network security is utterly complex and expensive, accessible only to companies with multi-million dollar IT budgets. In contrast, there are millions of companies left behind. Our mission is to create an advanced network and cybersecurity platform that is both easy-to-use and affordable for teams of all sizes and budgets. By leveraging the open-source strategy and technological advancements, NetBird aims to set the industry standard for connecting and securing IT infrastructure.",
|
||||
"webpageUrl": {
|
||||
"url": "https://github.com/netbirdio"
|
||||
}
|
||||
},
|
||||
"projects": [
|
||||
{
|
||||
"guid": "netbird",
|
||||
"name": "NetBird",
|
||||
"description": "NetBird is a configuration-free peer-to-peer private network and a centralized access control system combined in a single open-source platform. It makes it easy to create secure WireGuard-based private networks for your organization or home.",
|
||||
"webpageUrl": {
|
||||
"url": "https://github.com/netbirdio/netbird"
|
||||
},
|
||||
"repositoryUrl": {
|
||||
"url": "https://github.com/netbirdio/netbird"
|
||||
},
|
||||
"licenses": [
|
||||
"BSD-3"
|
||||
],
|
||||
"tags": [
|
||||
"network-security",
|
||||
"vpn",
|
||||
"developer-tools",
|
||||
"ztna",
|
||||
"zero-trust",
|
||||
"remote-access",
|
||||
"wireguard",
|
||||
"peer-to-peer",
|
||||
"private-networking",
|
||||
"software-defined-networking"
|
||||
]
|
||||
}
|
||||
],
|
||||
"funding": {
|
||||
"channels": [
|
||||
{
|
||||
"guid": "github-sponsors",
|
||||
"type": "payment-provider",
|
||||
"address": "https://github.com/sponsors/netbirdio",
|
||||
"description": ""
|
||||
},
|
||||
{
|
||||
"guid": "bank-transfer",
|
||||
"type": "bank",
|
||||
"address": "",
|
||||
"description": "Contact us at hello@netbird.io for bank transfer details."
|
||||
}
|
||||
],
|
||||
"plans": [
|
||||
{
|
||||
"guid": "support-yearly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - Yearly",
|
||||
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 100000,
|
||||
"currency": "USD",
|
||||
"frequency": "yearly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-one-time-year",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - One Year",
|
||||
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 100000,
|
||||
"currency": "USD",
|
||||
"frequency": "one-time",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-one-time-monthly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - Monthly",
|
||||
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 10000,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-monthly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - One Month",
|
||||
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 10000,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "goodwill",
|
||||
"status": "active",
|
||||
"name": "Goodwill Plan",
|
||||
"description": "Pay anything you wish to show your goodwill for the project.",
|
||||
"amount": 0,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
}
|
||||
],
|
||||
"history": null
|
||||
}
|
||||
}
|
||||
7
go.mod
7
go.mod
@@ -71,7 +71,6 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/r3labs/diff/v3 v3.0.1
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
@@ -156,7 +155,7 @@ require (
|
||||
github.com/go-text/typesetting v0.1.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
@@ -211,8 +210,6 @@ require (
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
@@ -231,7 +228,7 @@ require (
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
|
||||
k8s.io/apimachinery v0.26.2 // indirect
|
||||
)
|
||||
|
||||
|
||||
14
go.sum
14
go.sum
@@ -297,8 +297,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
@@ -605,8 +605,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
||||
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@@ -699,10 +697,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -1238,8 +1232,8 @@ gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
||||
@@ -873,7 +873,7 @@ services:
|
||||
zitadel:
|
||||
restart: 'always'
|
||||
networks: [netbird]
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.64.1'
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
env_file:
|
||||
- ./zitadel.env
|
||||
|
||||
@@ -153,6 +153,7 @@ type AccountManager interface {
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
|
||||
@@ -1010,7 +1010,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedSetupKey := setupKey.Key
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
@@ -1035,10 +1034,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String())
|
||||
}
|
||||
|
||||
if peer.SetupKey != expectedSetupKey {
|
||||
t.Errorf("expecting just added peer to have SetupKey = %s, got %s", expectedSetupKey, peer.SetupKey)
|
||||
}
|
||||
|
||||
if account.Network.CurrentSerial() != 1 {
|
||||
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
||||
}
|
||||
@@ -2367,7 +2362,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
LoginExpired: false,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
@@ -2375,7 +2369,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
LoginExpired: false,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
@@ -2529,7 +2522,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
LoginExpired: false,
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
@@ -2537,7 +2529,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
LoginExpired: false,
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
|
||||
@@ -146,6 +146,8 @@ const (
|
||||
AccountPeerInactivityExpirationEnabled Activity = 65
|
||||
AccountPeerInactivityExpirationDisabled Activity = 66
|
||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||
|
||||
SetupKeyDeleted Activity = 68
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@@ -219,6 +221,7 @@ var activityMap = map[Activity]Code{
|
||||
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
|
||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
package differs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
|
||||
"github.com/r3labs/diff/v3"
|
||||
)
|
||||
|
||||
// NetIPAddr is a custom differ for netip.Addr
|
||||
type NetIPAddr struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromAddr, ok1 := a.Interface().(netip.Addr)
|
||||
toAddr, ok2 := b.Interface().(netip.Addr)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromAddr.String() != toAddr.String() {
|
||||
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
|
||||
// NetIPPrefix is a custom differ for netip.Prefix
|
||||
type NetIPPrefix struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
||||
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromPrefix.String() != toPrefix.String() {
|
||||
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
@@ -8,9 +8,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -521,23 +522,64 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Creating DNS settings with groups that have no peers should not update account peers or send peer update
|
||||
t.Run("creating dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupB"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Creating DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("creating dns setting with used groups", func(t *testing.T) {
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
@@ -559,27 +601,6 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -536,29 +537,6 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving an unchanged group should trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding peer to a used group should update account peers and send peer update
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -530,10 +530,9 @@ components:
|
||||
type: string
|
||||
example: reusable
|
||||
expires_in:
|
||||
description: Expiration time in seconds
|
||||
description: Expiration time in seconds, 0 will mean the key never expires
|
||||
type: integer
|
||||
minimum: 86400
|
||||
maximum: 31536000
|
||||
minimum: 0
|
||||
example: 86400
|
||||
revoked:
|
||||
description: Setup key revocation status
|
||||
@@ -2018,6 +2017,32 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a Setup Key
|
||||
description: Delete a Setup Key
|
||||
tags: [ Setup Keys ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: keyId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a setup key
|
||||
responses:
|
||||
'200':
|
||||
description: Delete status code
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/groups:
|
||||
get:
|
||||
summary: List all Groups
|
||||
|
||||
@@ -1101,7 +1101,7 @@ type SetupKeyRequest struct {
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
||||
|
||||
// ExpiresIn Expiration time in seconds
|
||||
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Setup Key name
|
||||
|
||||
@@ -141,6 +141,7 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() {
|
||||
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (apiHandler *apiHandler) addPoliciesEndpoint() {
|
||||
|
||||
@@ -13,12 +13,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@@ -168,7 +169,6 @@ func TestGetPeers(t *testing.T) {
|
||||
peer := &nbpeer.Peer{
|
||||
ID: testPeerID,
|
||||
Key: "key",
|
||||
SetupKey: "setupkey",
|
||||
IP: net.ParseIP("100.64.0.1"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
Name: "PeerName",
|
||||
|
||||
@@ -61,10 +61,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
|
||||
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||
|
||||
day := time.Hour * 24
|
||||
year := day * 365
|
||||
if expiresIn < day || expiresIn > year {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
|
||||
if expiresIn < 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn can not be in the past"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -76,6 +74,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
if req.Ephemeral != nil {
|
||||
ephemeral = *req.Ephemeral
|
||||
}
|
||||
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
||||
if err != nil {
|
||||
@@ -83,7 +82,11 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
writeSuccess(r.Context(), w, setupKey)
|
||||
apiSetupKeys := toResponseBody(setupKey)
|
||||
// for the creation we need to send the plain key
|
||||
apiSetupKeys.Key = setupKey.Key
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||
}
|
||||
|
||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||
@@ -98,7 +101,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -123,7 +126,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -181,6 +184,30 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
|
||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
@@ -206,7 +233,7 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
|
||||
|
||||
return &api.SetupKey{
|
||||
Id: key.Id,
|
||||
Key: key.Key,
|
||||
Key: key.KeySecret,
|
||||
Name: key.Name,
|
||||
Expires: key.ExpiresAt,
|
||||
Type: string(key.Type),
|
||||
|
||||
@@ -67,6 +67,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
||||
return []*server.SetupKey{defaultKey}, nil
|
||||
},
|
||||
|
||||
DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error {
|
||||
if keyID == defaultKey.Id {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(status.NotFound, "key %s not found", keyID)
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
@@ -81,18 +88,21 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
|
||||
func TestSetupKeysHandlers(t *testing.T) {
|
||||
defaultSetupKey := server.GenerateDefaultSetupKey()
|
||||
defaultSetupKey, _ := server.GenerateDefaultSetupKey()
|
||||
defaultSetupKey.Id = existingSetupKeyID
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
|
||||
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
||||
newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
||||
server.SetupKeyUnlimitedUsage, true)
|
||||
newSetupKey.Key = plainKey
|
||||
updatedDefaultSetupKey := defaultSetupKey.Copy()
|
||||
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
|
||||
updatedDefaultSetupKey.Name = updatedSetupKeyName
|
||||
updatedDefaultSetupKey.Revoked = true
|
||||
|
||||
expectedNewKey := toResponseBody(newSetupKey)
|
||||
expectedNewKey.Key = plainKey
|
||||
tt := []struct {
|
||||
name string
|
||||
requestType string
|
||||
@@ -134,7 +144,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedSetupKey: toResponseBody(newSetupKey),
|
||||
expectedSetupKey: expectedNewKey,
|
||||
},
|
||||
{
|
||||
name: "Update Setup Key",
|
||||
@@ -150,6 +160,14 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
expectedBody: true,
|
||||
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
|
||||
},
|
||||
{
|
||||
name: "Delete Setup Key",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
|
||||
requestBody: bytes.NewBuffer([]byte("")),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: false,
|
||||
},
|
||||
}
|
||||
|
||||
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
|
||||
@@ -164,6 +182,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -267,7 +267,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
peersSSHEnabled++
|
||||
}
|
||||
|
||||
if peer.SetupKey == "" {
|
||||
if peer.UserID != "" {
|
||||
userPeers++
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,16 @@ package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
b64 "encoding/base64"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
@@ -205,3 +208,90 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error {
|
||||
oldColumnName := "key"
|
||||
newColumnName := "key_secret"
|
||||
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
if !tx.Migrator().HasColumn(&model, newColumnName) {
|
||||
log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", newColumnName, tableName)
|
||||
if err := tx.Migrator().AddColumn(&model, newColumnName); err != nil {
|
||||
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
||||
}
|
||||
}
|
||||
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).
|
||||
Select("id", oldColumnName, newColumnName).
|
||||
Where(newColumnName + " IS NULL OR " + newColumnName + " = ''").
|
||||
Where("SUBSTR(" + oldColumnName + ", 9, 1) = '-'").
|
||||
Find(&rows).Error; err != nil {
|
||||
return fmt.Errorf("find rows with empty secret key and matching pattern: %w", err)
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
log.WithContext(ctx).Infof("No plain setup keys found in table %s, no migration needed", tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var plainKey string
|
||||
if columnValue := row[oldColumnName]; columnValue != nil {
|
||||
value, ok := columnValue.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("type assertion failed")
|
||||
}
|
||||
plainKey = value
|
||||
}
|
||||
|
||||
secretKey := hiddenKey(plainKey, 4)
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(plainKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, secretKey).Error; err != nil {
|
||||
return fmt.Errorf("update row with secret key: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(oldColumnName, encodedHashedKey).Error; err != nil {
|
||||
return fmt.Errorf("update row with hashed key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Migration of plain setup key to hashed setup key completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
|
||||
// E.g., "831F6*******************************"
|
||||
func hiddenKey(key string, length int) string {
|
||||
prefix := key[0:5]
|
||||
if length > utf8.RuneCountInString(key) {
|
||||
length = utf8.RuneCountInString(key) - len(prefix)
|
||||
}
|
||||
return prefix + strings.Repeat("*", length)
|
||||
}
|
||||
|
||||
@@ -160,3 +160,72 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
|
||||
}
|
||||
|
||||
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.SetupKey{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.SetupKey{
|
||||
Id: "1",
|
||||
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||
|
||||
var key server.SetupKey
|
||||
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||
assert.NoError(t, err, "Failed to fetch setup key")
|
||||
|
||||
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
|
||||
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||
}
|
||||
|
||||
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.SetupKey{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.SetupKey{
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
KeySecret: "EEFDA****",
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||
|
||||
var key server.SetupKey
|
||||
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||
assert.NoError(t, err, "Failed to fetch setup key")
|
||||
|
||||
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
|
||||
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||
}
|
||||
|
||||
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.SetupKey{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.SetupKey{
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||
|
||||
var key server.SetupKey
|
||||
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||
assert.NoError(t, err, "Failed to fetch setup key")
|
||||
|
||||
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||
}
|
||||
|
||||
@@ -109,6 +109,14 @@ type MockAccountManager struct {
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
if am.DeleteSetupKeyFunc != nil {
|
||||
return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
|
||||
@@ -1065,36 +1065,6 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// saving unchanged nameserver group should update account peers and not send peer update
|
||||
t.Run("saving unchanged nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
}
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting a nameserver group should update account peers and send peer update
|
||||
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -41,9 +41,9 @@ type Network struct {
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64 `diff:"-"`
|
||||
Serial uint64
|
||||
|
||||
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
|
||||
mu sync.Mutex `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
|
||||
@@ -2,6 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
@@ -396,6 +398,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
var accountID string
|
||||
var err error
|
||||
addedByUser := false
|
||||
@@ -403,7 +407,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
addedByUser = true
|
||||
accountID, err = am.Store.GetAccountIDByUserID(userID)
|
||||
} else {
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey)
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
@@ -448,7 +452,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
} else {
|
||||
// Validate the setup key
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
@@ -489,7 +493,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
@@ -586,6 +589,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
|
||||
}
|
||||
|
||||
groupsToAdd = append(groupsToAdd, allGroup.ID)
|
||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -16,38 +17,36 @@ type Peer struct {
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// A setup key this peer was registered with
|
||||
SetupKey string `diff:"-"`
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
// The user ID that registered the peer
|
||||
UserID string `diff:"-"`
|
||||
UserID string
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
SSHEnabled bool
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool `diff:"-"`
|
||||
LoginExpirationEnabled bool
|
||||
|
||||
InactivityExpirationEnabled bool `diff:"-"`
|
||||
InactivityExpirationEnabled bool
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time `diff:"-"`
|
||||
LastLogin time.Time
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time `diff:"-"`
|
||||
CreatedAt time.Time
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool `diff:"-"`
|
||||
Ephemeral bool
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
@@ -109,6 +108,12 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
|
||||
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
|
||||
})
|
||||
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
|
||||
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
|
||||
})
|
||||
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
|
||||
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
|
||||
})
|
||||
@@ -116,6 +121,12 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(p.Files, func(i, j int) bool {
|
||||
return p.Files[i].Path < p.Files[j].Path
|
||||
})
|
||||
sort.Slice(other.Files, func(i, j int) bool {
|
||||
return other.Files[i].Path < other.Files[j].Path
|
||||
})
|
||||
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
|
||||
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
|
||||
})
|
||||
@@ -172,23 +183,22 @@ func (p *Peer) Copy() *Peer {
|
||||
peerStatus = p.Status.Copy()
|
||||
}
|
||||
return &Peer{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
Key: p.Key,
|
||||
SetupKey: p.SetupKey,
|
||||
IP: p.IP,
|
||||
Meta: p.Meta,
|
||||
Name: p.Name,
|
||||
DNSLabel: p.DNSLabel,
|
||||
Status: peerStatus,
|
||||
UserID: p.UserID,
|
||||
SSHKey: p.SSHKey,
|
||||
SSHEnabled: p.SSHEnabled,
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
LastLogin: p.LastLogin,
|
||||
CreatedAt: p.CreatedAt,
|
||||
Ephemeral: p.Ephemeral,
|
||||
Location: p.Location,
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
Key: p.Key,
|
||||
IP: p.IP,
|
||||
Meta: p.Meta,
|
||||
Name: p.Name,
|
||||
DNSLabel: p.DNSLabel,
|
||||
Status: peerStatus,
|
||||
UserID: p.UserID,
|
||||
SSHKey: p.SSHKey,
|
||||
SSHEnabled: p.SSHEnabled,
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
LastLogin: p.LastLogin,
|
||||
CreatedAt: p.CreatedAt,
|
||||
Ephemeral: p.Ephemeral,
|
||||
Location: p.Location,
|
||||
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -29,3 +30,56 @@ func BenchmarkFQDN(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsEqual(t *testing.T) {
|
||||
meta1 := PeerSystemMeta{
|
||||
NetworkAddresses: []NetworkAddress{{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.2/24"),
|
||||
Mac: "2",
|
||||
},
|
||||
{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Mac: "1",
|
||||
},
|
||||
},
|
||||
Files: []File{
|
||||
{
|
||||
Path: "/etc/hosts1",
|
||||
Exist: true,
|
||||
ProcessIsRunning: true,
|
||||
},
|
||||
{
|
||||
Path: "/etc/hosts2",
|
||||
Exist: false,
|
||||
ProcessIsRunning: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
meta2 := PeerSystemMeta{
|
||||
NetworkAddresses: []NetworkAddress{
|
||||
{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Mac: "1",
|
||||
},
|
||||
{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.2/24"),
|
||||
Mac: "2",
|
||||
},
|
||||
},
|
||||
Files: []File{
|
||||
{
|
||||
Path: "/etc/hosts2",
|
||||
Exist: false,
|
||||
ProcessIsRunning: false,
|
||||
},
|
||||
{
|
||||
Path: "/etc/hosts1",
|
||||
Exist: true,
|
||||
ProcessIsRunning: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
if !meta1.isEqual(meta2) {
|
||||
t.Error("meta1 should be equal to meta2")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -1090,7 +1092,6 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
@@ -1155,7 +1156,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "existingSetupKey",
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
@@ -1175,7 +1175,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
@@ -1187,8 +1186,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
|
||||
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(existingSetupKeyID))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
assert.NotEqual(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed)
|
||||
assert.Equal(t, 1, account.SetupKeys[encodedHashedKey].UsedTimes)
|
||||
|
||||
}
|
||||
|
||||
@@ -1221,7 +1223,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "existingSetupKey",
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
@@ -1250,8 +1251,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
|
||||
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(faultyKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed.UTC())
|
||||
assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes)
|
||||
}
|
||||
|
||||
func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
@@ -405,7 +405,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -854,16 +854,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
@@ -883,7 +878,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -918,7 +913,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -953,7 +948,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -987,7 +982,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1021,7 +1016,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1056,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1090,7 +1085,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1104,46 +1099,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged policy should trigger account peers update but not send peer update
|
||||
t.Run("saving unchanged policy", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting policy should trigger account peers update and send peer update
|
||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policyID := "policy-source-destination-peers"
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1164,7 +1126,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
policyID := "policy-destination-has-peers-source-none"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1180,10 +1142,10 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
|
||||
policyID := "policy-rule-groups-no-peers"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
)
|
||||
|
||||
@@ -264,25 +265,6 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged posture check should not trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing posture check from policy should trigger account peers update and send peer update
|
||||
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
@@ -412,50 +394,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture check to policy where source has peers but destination does not,
|
||||
// should not trigger account peers update or send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
|
||||
@@ -1938,26 +1938,6 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Updating unchanged route should update account peers and not send peer update
|
||||
t.Run("updating unchanged route", func(t *testing.T) {
|
||||
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting the route should update account peers and send peer update
|
||||
t.Run("deleting route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -2,6 +2,9 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -73,6 +76,7 @@ type SetupKey struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Key string
|
||||
KeySecret string
|
||||
Name string
|
||||
Type SetupKeyType
|
||||
CreatedAt time.Time
|
||||
@@ -104,6 +108,7 @@ func (key *SetupKey) Copy() *SetupKey {
|
||||
Id: key.Id,
|
||||
AccountID: key.AccountID,
|
||||
Key: key.Key,
|
||||
KeySecret: key.KeySecret,
|
||||
Name: key.Name,
|
||||
Type: key.Type,
|
||||
CreatedAt: key.CreatedAt,
|
||||
@@ -120,19 +125,17 @@ func (key *SetupKey) Copy() *SetupKey {
|
||||
|
||||
// EventMeta returns activity event meta related to the setup key
|
||||
func (key *SetupKey) EventMeta() map[string]any {
|
||||
return map[string]any{"name": key.Name, "type": key.Type, "key": key.HiddenCopy(1).Key}
|
||||
return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret}
|
||||
}
|
||||
|
||||
// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix.
|
||||
// hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
|
||||
// E.g., "831F6*******************************"
|
||||
func (key *SetupKey) HiddenCopy(length int) *SetupKey {
|
||||
k := key.Copy()
|
||||
prefix := k.Key[0:5]
|
||||
if length > utf8.RuneCountInString(key.Key) {
|
||||
length = utf8.RuneCountInString(key.Key) - len(prefix)
|
||||
func hiddenKey(key string, length int) string {
|
||||
prefix := key[0:5]
|
||||
if length > utf8.RuneCountInString(key) {
|
||||
length = utf8.RuneCountInString(key) - len(prefix)
|
||||
}
|
||||
k.Key = prefix + strings.Repeat("*", length)
|
||||
return k
|
||||
return prefix + strings.Repeat("*", length)
|
||||
}
|
||||
|
||||
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
|
||||
@@ -155,6 +158,9 @@ func (key *SetupKey) IsRevoked() bool {
|
||||
|
||||
// IsExpired if key was expired
|
||||
func (key *SetupKey) IsExpired() bool {
|
||||
if key.ExpiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(key.ExpiresAt)
|
||||
}
|
||||
|
||||
@@ -169,30 +175,40 @@ func (key *SetupKey) IsOverUsed() bool {
|
||||
|
||||
// GenerateSetupKey generates a new setup key
|
||||
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string,
|
||||
usageLimit int, ephemeral bool) *SetupKey {
|
||||
usageLimit int, ephemeral bool) (*SetupKey, string) {
|
||||
key := strings.ToUpper(uuid.New().String())
|
||||
limit := usageLimit
|
||||
if t == SetupKeyOneOff {
|
||||
limit = 1
|
||||
}
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if validFor != 0 {
|
||||
expiresAt = time.Now().UTC().Add(validFor)
|
||||
}
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(key))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
return &SetupKey{
|
||||
Id: strconv.Itoa(int(Hash(key))),
|
||||
Key: key,
|
||||
Key: encodedHashedKey,
|
||||
KeySecret: hiddenKey(key, 4),
|
||||
Name: name,
|
||||
Type: t,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(validFor),
|
||||
ExpiresAt: expiresAt,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
Revoked: false,
|
||||
UsedTimes: 0,
|
||||
AutoGroups: autoGroups,
|
||||
UsageLimit: limit,
|
||||
Ephemeral: ephemeral,
|
||||
}
|
||||
}, key
|
||||
}
|
||||
|
||||
// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration
|
||||
func GenerateDefaultSetupKey() *SetupKey {
|
||||
func GenerateDefaultSetupKey() (*SetupKey, string) {
|
||||
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
|
||||
SetupKeyUnlimitedUsage, false)
|
||||
}
|
||||
@@ -213,11 +229,6 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
keyDuration := DefaultSetupKeyDuration
|
||||
if expiresIn != 0 {
|
||||
keyDuration = expiresIn
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -227,7 +238,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral)
|
||||
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
@@ -246,6 +257,9 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
}
|
||||
}
|
||||
|
||||
// for the creation return the plain key to the caller
|
||||
setupKey.Key = plainKey
|
||||
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
@@ -334,7 +348,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
}
|
||||
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
@@ -342,18 +356,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := make([]*SetupKey, 0, len(setupKeys))
|
||||
for _, key := range setupKeys {
|
||||
var k *SetupKey
|
||||
if !user.IsAdminOrServiceUser() {
|
||||
k = key.HiddenCopy(999)
|
||||
} else {
|
||||
k = key.Copy()
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
return setupKeys, nil
|
||||
}
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
@@ -364,7 +367,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
@@ -377,11 +380,33 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
setupKey.UpdatedAt = setupKey.CreatedAt
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() {
|
||||
setupKey = setupKey.HiddenCopy(999)
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
// DeleteSetupKey removes the setup key from the account
|
||||
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return status.NewUnauthorizedToViewSetupKeysError()
|
||||
}
|
||||
|
||||
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete setup key: %w", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||
|
||||
@@ -2,8 +2,11 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -66,7 +69,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
key.Id, time.Now().UTC(), autoGroups)
|
||||
key.Id, time.Now().UTC(), autoGroups, true)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked)
|
||||
@@ -183,7 +186,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
|
||||
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
|
||||
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
|
||||
tCase.expectedUpdatedAt, tCase.expectedGroups)
|
||||
tCase.expectedUpdatedAt, tCase.expectedGroups, false)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated)
|
||||
@@ -239,10 +242,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
|
||||
expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour)
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key := GenerateDefaultSetupKey()
|
||||
key, plainKey := GenerateDefaultSetupKey()
|
||||
|
||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true)
|
||||
|
||||
}
|
||||
|
||||
@@ -256,41 +259,41 @@ func TestGenerateSetupKey(t *testing.T) {
|
||||
expectedUpdatedAt := time.Now().UTC()
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
|
||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true)
|
||||
|
||||
}
|
||||
|
||||
func TestSetupKey_IsValid(t *testing.T) {
|
||||
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
if !validKey.IsValid() {
|
||||
t.Errorf("expected key to be valid, got invalid %v", validKey)
|
||||
}
|
||||
|
||||
// expired
|
||||
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
if expiredKey.IsValid() {
|
||||
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
|
||||
}
|
||||
|
||||
// revoked
|
||||
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
revokedKey.Revoked = true
|
||||
if revokedKey.IsValid() {
|
||||
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
|
||||
}
|
||||
|
||||
// overused
|
||||
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
overUsedKey.UsedTimes = 1
|
||||
if overUsedKey.IsValid() {
|
||||
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
|
||||
}
|
||||
|
||||
// overused
|
||||
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
reusableKey.UsedTimes = 99
|
||||
if !reusableKey.IsValid() {
|
||||
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
|
||||
@@ -299,7 +302,7 @@ func TestSetupKey_IsValid(t *testing.T) {
|
||||
|
||||
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
|
||||
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
|
||||
expectedUpdatedAt time.Time, expectedAutoGroups []string) {
|
||||
expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) {
|
||||
t.Helper()
|
||||
if key.Name != expectedName {
|
||||
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
|
||||
@@ -329,13 +332,23 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
|
||||
t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt)
|
||||
}
|
||||
|
||||
_, err := uuid.Parse(key.Key)
|
||||
if err != nil {
|
||||
t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err)
|
||||
if expectHashedKey {
|
||||
if !isValidBase64SHA256(key.Key) {
|
||||
t.Errorf("expected key to be hashed, got %v", key.Key)
|
||||
}
|
||||
} else {
|
||||
_, err := uuid.Parse(key.Key)
|
||||
if err != nil {
|
||||
t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err)
|
||||
}
|
||||
}
|
||||
|
||||
if key.Id != strconv.Itoa(int(Hash(key.Key))) {
|
||||
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id)
|
||||
if !strings.HasSuffix(key.KeySecret, "****") {
|
||||
t.Errorf("expected key secret to be secure, got %v", key.Key)
|
||||
}
|
||||
|
||||
if key.Id != expectedID {
|
||||
t.Errorf("expected key Id %v, got %v", expectedID, key.Id)
|
||||
}
|
||||
|
||||
if len(key.AutoGroups) != len(expectedAutoGroups) {
|
||||
@@ -344,13 +357,26 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
|
||||
assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal")
|
||||
}
|
||||
|
||||
func isValidBase64SHA256(encodedKey string) bool {
|
||||
decoded, err := base64.StdEncoding.DecodeString(encodedKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(decoded) != sha256.Size {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestSetupKey_Copy(t *testing.T) {
|
||||
|
||||
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
keyCopy := key.Copy()
|
||||
|
||||
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
|
||||
key.UpdatedAt, key.AutoGroups)
|
||||
key.UpdatedAt, key.AutoGroups, true)
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -292,6 +292,8 @@ func (s *SqlStore) GetInstallationID() string {
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
|
||||
peerCopy := peer.Copy()
|
||||
peerCopy.AccountID = accountID
|
||||
@@ -317,6 +319,9 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -324,6 +329,8 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
startTime := time.Now()
|
||||
|
||||
accountCopy := Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
@@ -336,6 +343,9 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -347,6 +357,8 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
|
||||
@@ -359,6 +371,9 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Updates(&peerCopy)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -370,6 +385,8 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
@@ -381,6 +398,9 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -394,6 +414,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
startTime := time.Now()
|
||||
|
||||
usersToSave := make([]User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
@@ -403,15 +425,28 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
}
|
||||
usersToSave = append(usersToSave, *user)
|
||||
}
|
||||
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveUser saves the given user to the database.
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -419,12 +454,17 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
|
||||
|
||||
// SaveGroups saves the given list of groups to the database.
|
||||
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
||||
startTime := time.Now()
|
||||
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -451,6 +491,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
@@ -460,6 +502,9 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -468,12 +513,17 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var key SetupKey
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
|
||||
@@ -485,12 +535,17 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -499,12 +554,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -528,6 +588,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
@@ -535,6 +597,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
@@ -542,12 +607,17 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||
}
|
||||
@@ -556,12 +626,17 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting groups from store")
|
||||
}
|
||||
@@ -661,12 +736,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -678,12 +758,17 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -695,13 +780,17 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -713,6 +802,8 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||
@@ -720,6 +811,9 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -727,12 +821,17 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -740,12 +839,17 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
|
||||
@@ -757,6 +861,8 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var ipJSONStrings []string
|
||||
|
||||
// Fetch the IP addresses as JSON strings
|
||||
@@ -767,6 +873,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -784,8 +893,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
var labels []string
|
||||
startTime := time.Now()
|
||||
|
||||
var labels []string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("dns_label", &labels)
|
||||
@@ -794,6 +904,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
|
||||
}
|
||||
@@ -802,24 +915,33 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
var accountNetwork AccountNetwork
|
||||
startTime := time.Now()
|
||||
|
||||
var accountNetwork AccountNetwork
|
||||
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -827,11 +949,16 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
|
||||
}
|
||||
return accountSettings.Settings, nil
|
||||
@@ -839,13 +966,17 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.NewGetUserFromStoreError()
|
||||
}
|
||||
user.LastLogin = lastLogin
|
||||
@@ -854,6 +985,8 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
definitionJSON, err := json.Marshal(checks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -862,6 +995,9 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p
|
||||
var postureCheck posture.Checks
|
||||
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -971,19 +1107,26 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var setupKey SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
|
||||
First(&setupKey, keyQueryCondition, key)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
return &setupKey, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
@@ -992,6 +1135,9 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1003,13 +1149,17 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
var group nbgroup.Group
|
||||
startTime := time.Now()
|
||||
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1022,6 +1172,9 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
|
||||
}
|
||||
|
||||
@@ -1029,13 +1182,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
var group nbgroup.Group
|
||||
startTime := time.Now()
|
||||
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group not found for account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1048,6 +1205,9 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
||||
}
|
||||
|
||||
@@ -1060,7 +1220,12 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1068,8 +1233,13 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -1100,14 +1270,18 @@ func (s *SqlStore) GetDB() *gorm.DB {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||
var accountDNSSettings AccountDNSSettings
|
||||
startTime := time.Now()
|
||||
|
||||
var accountDNSSettings AccountDNSSettings
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||
}
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
@@ -1115,14 +1289,18 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
var accountID string
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select("id").First(&accountID, idQueryCondition, id)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return false, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return false, result.Error
|
||||
}
|
||||
|
||||
@@ -1131,14 +1309,18 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
var account Account
|
||||
startTime := time.Now()
|
||||
|
||||
var account Account
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", "", status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
|
||||
}
|
||||
|
||||
@@ -1232,6 +1414,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
|
||||
return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID)
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
@@ -1264,3 +1450,21 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// deleteRecordByID deletes a record by its ID and account ID from the database.
|
||||
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
|
||||
var record T
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "record not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -71,7 +73,7 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
const numPerAccount = 6000
|
||||
for n := 0; n < numPerAccount; n++ {
|
||||
@@ -81,7 +83,6 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
peer := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
Key: peerID,
|
||||
SetupKey: "",
|
||||
IP: netIP,
|
||||
Name: peerID,
|
||||
DNSLabel: peerID,
|
||||
@@ -133,7 +134,7 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
}
|
||||
account.NameServerGroups[nameserver.ID] = nameserver
|
||||
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
}
|
||||
|
||||
@@ -215,30 +216,28 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey = GenerateDefaultSetupKey()
|
||||
setupKey, _ = GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
SetupKey: "peerkeysetupkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account2)
|
||||
@@ -297,15 +296,14 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
}}
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
|
||||
@@ -394,13 +392,12 @@ func TestSqlite_SavePeer(t *testing.T) {
|
||||
|
||||
// save status of non-existing peer
|
||||
peer := &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
ctx := context.Background()
|
||||
err = store.SavePeer(ctx, account.Id, peer)
|
||||
@@ -453,13 +450,12 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
|
||||
// save new status of existing peer
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
@@ -720,15 +716,14 @@ func newSqliteStore(t *testing.T) *SqlStore {
|
||||
func newAccount(store Store, id int) error {
|
||||
str := fmt.Sprintf("%s-%d", uuid.New().String(), id)
|
||||
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["p"+str] = &nbpeer.Peer{
|
||||
Key: "peerkey" + str,
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey" + str,
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
return store.SaveAccount(context.Background(), account)
|
||||
@@ -760,30 +755,28 @@ func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey = GenerateDefaultSetupKey()
|
||||
setupKey, _ = GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
SetupKey: "peerkeysetupkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account2)
|
||||
@@ -842,15 +835,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
}}
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
|
||||
@@ -921,13 +913,12 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
|
||||
// save new status of existing peer
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
@@ -1118,12 +1109,17 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
hashedKey := sha256.Sum256([]byte(plainKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
|
||||
assert.Equal(t, encodedHashedKey, setupKey.Key)
|
||||
assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret)
|
||||
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
|
||||
assert.Equal(t, "Default key", setupKey.Name)
|
||||
}
|
||||
@@ -1138,24 +1134,28 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
hashedKey := sha256.Sum256([]byte(plainKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, setupKey.UsedTimes)
|
||||
|
||||
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, setupKey.UsedTimes)
|
||||
|
||||
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, setupKey.UsedTimes)
|
||||
}
|
||||
@@ -1264,3 +1264,32 @@ func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "All", group.Name)
|
||||
}
|
||||
|
||||
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
nonExistingKeyID := "non-existing-key-id"
|
||||
|
||||
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package status
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -114,3 +115,18 @@ func NewGetAccountFromStoreError(err error) error {
|
||||
func NewGetUserFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
||||
}
|
||||
|
||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||
func NewInvalidKeyIDError() error {
|
||||
return Errorf(InvalidArgument, "invalid key ID")
|
||||
}
|
||||
|
||||
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key
|
||||
func NewUnauthorizedToViewSetupKeysError() error {
|
||||
return Errorf(Unauthorized, "only users with admin power can view setup keys")
|
||||
}
|
||||
|
||||
@@ -124,6 +124,7 @@ type Store interface {
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
@@ -241,6 +242,9 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,9 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/differs"
|
||||
"github.com/r3labs/diff/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -25,8 +21,6 @@ type UpdateMessage struct {
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
|
||||
peerUpdateMessage map[string]*UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
@@ -36,10 +30,9 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerUpdateMessage: make(map[string]*UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,15 +41,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
// skip sending sync update to the peer if there is no change in update message,
|
||||
// it will not check on turn credential refresh as we do not send network map or client posture checks
|
||||
if update.NetworkMap != nil {
|
||||
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
|
||||
if !updated {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.channelsMux.Lock()
|
||||
|
||||
defer func() {
|
||||
@@ -66,16 +50,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
if update.NetworkMap != nil {
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
|
||||
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
|
||||
return
|
||||
}
|
||||
p.peerUpdateMessage[peerID] = update
|
||||
}
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
@@ -108,7 +82,6 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
@@ -123,7 +96,6 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
@@ -200,72 +172,3 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
|
||||
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
|
||||
p.channelsMux.RLock()
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
p.channelsMux.RUnlock()
|
||||
|
||||
if lastSentUpdate != nil {
|
||||
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
|
||||
return false
|
||||
}
|
||||
if !updated {
|
||||
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
|
||||
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
|
||||
isNew, err = true, nil
|
||||
}
|
||||
}()
|
||||
|
||||
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
differ, err := diff.NewDiffer(
|
||||
diff.CustomValueDiffers(&differs.NetIPAddr{}),
|
||||
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create differ: %v", err)
|
||||
}
|
||||
|
||||
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
|
||||
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
|
||||
|
||||
changelog, err := differ.Diff(lastSentFiles, currFiles)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff checks: %v", err)
|
||||
}
|
||||
if len(changelog) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff network map: %v", err)
|
||||
}
|
||||
return len(changelog) > 0, nil
|
||||
}
|
||||
|
||||
// getChecksFiles returns a list of files from the given checks.
|
||||
func getChecksFiles(checks []*proto.Checks) []string {
|
||||
files := make([]string, 0, len(checks))
|
||||
for _, check := range checks {
|
||||
files = append(files, check.GetFiles()...)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
@@ -2,19 +2,10 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// var peersUpdater *PeersUpdateManager
|
||||
@@ -86,470 +77,3 @@ func TestCloseChannel(t *testing.T) {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePeerMessageUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
existingUpdate *UpdateMessage
|
||||
newUpdate *UpdateMessage
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "update message with turn credentials update",
|
||||
peerID: "peer",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{},
|
||||
},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message for peer without existing update",
|
||||
peerID: "peer1",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with no changes in update",
|
||||
peerID: "peer2",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "update message with changes in checks",
|
||||
peerID: "peer3",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
Checks: []*proto.Checks{
|
||||
{
|
||||
Files: []string{"/usr/bin/netbird"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with lower serial number",
|
||||
peerID: "peer4",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewPeersUpdateManager(nil)
|
||||
ctx := context.Background()
|
||||
|
||||
if tt.existingUpdate != nil {
|
||||
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
|
||||
}
|
||||
|
||||
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNewPeerUpdateMessage(t *testing.T) {
|
||||
t.Run("Unchanged value", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating routes network", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Updating routes groups", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating network map peers", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.4"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer4-key",
|
||||
DNSLabel: "peer4",
|
||||
SSHKey: "peer4-ssh-key",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating process check", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
|
||||
newUpdateMessage3 := createMockUpdateMessage(t)
|
||||
newUpdateMessage3.Update.Checks = []*proto.Checks{}
|
||||
newUpdateMessage3.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage4 := createMockUpdateMessage(t)
|
||||
check := &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/local/netbird",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage4.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage5 := createMockUpdateMessage(t)
|
||||
check = &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/local/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage5.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating DNS configuration", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newDomain := "newexample.com"
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
|
||||
newDomain,
|
||||
)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating peer IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Add new firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newRule := &FirewallRule{
|
||||
PeerIP: "192.168.1.3",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Action: string(PolicyTrafficActionDrop),
|
||||
Protocol: string(PolicyRuleProtocolUDP),
|
||||
Port: "53",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Removing nameserver", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating name server IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating custom DNS zone", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
|
||||
t.Helper()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
|
||||
secretManager := NewTimeBasedAuthSecretsManager(
|
||||
NewPeersUpdateManager(nil),
|
||||
&TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{
|
||||
Duration: defaultDuration,
|
||||
},
|
||||
Secret: "secret",
|
||||
Turns: []*Host{TurnTestHost},
|
||||
},
|
||||
&Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "secret",
|
||||
},
|
||||
)
|
||||
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipNet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
{
|
||||
ID: "route2",
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route2",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
turnToken, err := secretManager.GenerateTurnToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
relayToken, err := secretManager.GenerateRelayToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &UpdateMessage{
|
||||
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
|
||||
NetworkMap: networkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
keepUnusedServerTime = 5 * time.Second
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
@@ -27,10 +28,13 @@ type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
err error
|
||||
created time.Time
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{}
|
||||
return &RelayTrack{
|
||||
created: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
@@ -302,6 +306,18 @@ func (m *Manager) cleanUpUnusedRelays() {
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
// if the connection failed to the server the relay client will be nil
|
||||
// but the instance will be kept in the relayClients until the next locking
|
||||
if rt.err != nil {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(rt.created) <= keepUnusedServerTime {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
|
||||
@@ -288,8 +288,9 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||
time.Sleep(timeout)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
@@ -13,7 +12,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
PeerID: "test",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
|
||||
Reference in New Issue
Block a user