mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
9 Commits
test/recon
...
test/seed-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
581b70b7aa | ||
|
|
e0bed2b0fb | ||
|
|
30f025e7dd | ||
|
|
b4d7605147 | ||
|
|
08b6e9d647 | ||
|
|
67ce14eaea | ||
|
|
669904cd06 | ||
|
|
4be826450b | ||
|
|
738387f2de |
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -13,10 +14,11 @@ import (
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
serverInstanceMu sync.Mutex
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
|
||||
@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||
|
||||
p.serverInstanceMu.Lock()
|
||||
p.serverInstance = serverInstance
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
in := new(proto.DownRequest)
|
||||
_, err := p.serverInstance.Down(p.ctx, in)
|
||||
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
|
||||
log.Errorf("failed to stop daemon: %v", err)
|
||||
}
|
||||
}
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
p.cancel()
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
@@ -24,8 +25,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -154,7 +155,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
@@ -166,16 +167,30 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
msgs := getMessages(msgsPool)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer putMessages(msgs, msgsPool)
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||
//nolint
|
||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
@@ -191,11 +206,12 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
continue
|
||||
}
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
@@ -273,3 +289,15 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
|
||||
return p.remoteConn.Close()
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
|
||||
@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
|
||||
@@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||
defer func() {
|
||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkSignalDisconnected(err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
|
||||
@@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
@@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
wgConfigWorkaround()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
@@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = proxy
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
|
||||
func (s *State) GetRoutes() map[string]struct{} {
|
||||
s.Mux.RLock()
|
||||
defer s.Mux.RUnlock()
|
||||
return s.routes
|
||||
return maps.Clone(s.routes)
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
if receivedState.GetRoutes() != nil {
|
||||
peerState.SetRoutes(receivedState.GetRoutes())
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
|
||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||
@@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.AddRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.DeleteRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
|
||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
ch, found := d.changeNotify[peer]
|
||||
if !found || ch == nil {
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
if found {
|
||||
return ch
|
||||
}
|
||||
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
@@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
ch, found := d.changeNotify[peerID]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(d.changeNotify, peerID)
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(d.numOfPeers())
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerState(peerState)
|
||||
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
select {
|
||||
|
||||
@@ -57,6 +57,9 @@ type WorkerICE struct {
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
||||
@@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
|
||||
return
|
||||
}
|
||||
|
||||
err := w.agent.Close()
|
||||
if err != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
|
||||
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
|
||||
w.muxAgent.Lock()
|
||||
agentCancel()
|
||||
_ = agent.Close()
|
||||
w.agent = nil
|
||||
|
||||
w.muxAgent.Unlock()
|
||||
switch state {
|
||||
case ice.ConnectionStateConnected:
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
}
|
||||
w.closeAgent(agentCancel)
|
||||
default:
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
@@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
cancel()
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
w.agent = nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
// wait local endpoint configuration
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
tempScore = float64(metricDiff) * 10
|
||||
}
|
||||
|
||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||
latency := time.Second
|
||||
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
|
||||
latency := 999 * time.Millisecond
|
||||
if peerStatus.latency != 0 {
|
||||
latency = peerStatus.latency
|
||||
} else {
|
||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
||||
}
|
||||
|
||||
// avoid negative tempScore on the higher latency calculation
|
||||
if latency > 1*time.Second {
|
||||
latency = 999 * time.Millisecond
|
||||
}
|
||||
|
||||
// higher latency is worse score
|
||||
tempScore += 1 - latency.Seconds()
|
||||
|
||||
if !peerStatus.relayed {
|
||||
@@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
||||
|
||||
switch {
|
||||
case chosen == "":
|
||||
var peers []string
|
||||
@@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
for _, r := range c.routes {
|
||||
_, found := c.routePeersNotifiers[r.Peer]
|
||||
if !found {
|
||||
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
closerChan := make(chan struct{})
|
||||
c.routePeersNotifiers[r.Peer] = closerChan
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
@@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
@@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
} else {
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
}
|
||||
@@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
c.addStateRoute()
|
||||
|
||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add peer state route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -287,6 +288,26 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// fill the test data with random routes
|
||||
for _, tc := range testCases {
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: "peer1",
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: "peer2",
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentRoute := &route.Route{
|
||||
|
||||
@@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
|
||||
2
go.mod
2
go.mod
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
|
||||
@@ -982,6 +982,110 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "example.com",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
}
|
||||
|
||||
publicClaims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "test.com",
|
||||
UserId: "public-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
}
|
||||
|
||||
am, err := createManager(b)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
return
|
||||
}
|
||||
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
users := genUsers("priv", 100)
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), id)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
acc.Users = users
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), acc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
userP := genUsers("pub", 100)
|
||||
|
||||
pacc, err := am.Store.GetAccount(context.Background(), pid)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pacc.Users = userP
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), pacc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.Run("public without account ID", func(b *testing.B) {
|
||||
//b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("private without account ID", func(b *testing.B) {
|
||||
//b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("private with account ID", func(b *testing.B) {
|
||||
claims.AccountId = id
|
||||
//b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func genUsers(p string, n int) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
now := time.Now()
|
||||
for i := 0; i < n; i++ {
|
||||
users[fmt.Sprintf("%s-%d", p, i)] = &User{
|
||||
Id: fmt.Sprintf("%s-%d", p, i),
|
||||
Role: UserRoleAdmin,
|
||||
LastLogin: now,
|
||||
CreatedAt: now,
|
||||
Issued: "api",
|
||||
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -38,6 +39,7 @@ type GRPCServer struct {
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
@@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
@@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
@@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// maps internal internalStatus.Error to gRPC status.Error
|
||||
func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
|
||||
@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
||||
}
|
||||
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
|
||||
@@ -300,13 +300,11 @@ func (s *SqlStore) GetInstallationID() string {
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
|
||||
peerCopy := peer.Copy()
|
||||
peerCopy.AccountID = accountID
|
||||
|
||||
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// check if peer exists before saving
|
||||
var peerID string
|
||||
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
|
||||
@@ -327,9 +325,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -337,8 +332,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
startTime := time.Now()
|
||||
|
||||
accountCopy := Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
@@ -346,14 +339,11 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).
|
||||
result := s.db.Model(&Account{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -365,8 +355,6 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
|
||||
@@ -379,9 +367,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Updates(&peerCopy)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -393,8 +378,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
@@ -406,9 +389,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -422,8 +402,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
startTime := time.Now()
|
||||
|
||||
usersToSave := make([]User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
@@ -437,9 +415,6 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
|
||||
}
|
||||
|
||||
@@ -448,13 +423,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
|
||||
// SaveUser saves the given user to the database.
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -462,17 +432,12 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
|
||||
|
||||
// SaveGroups saves the given list of groups to the database.
|
||||
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
||||
startTime := time.Now()
|
||||
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -499,10 +464,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
strings.ToLower(domain), true, PrivateCategory,
|
||||
).First(&accountID)
|
||||
@@ -510,9 +473,6 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -521,17 +481,12 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var key SetupKey
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
|
||||
@@ -543,17 +498,12 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -562,17 +512,12 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -596,18 +541,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
@@ -615,17 +555,12 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||
}
|
||||
@@ -634,17 +569,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting groups from store")
|
||||
}
|
||||
@@ -744,17 +674,12 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
|
||||
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -766,17 +691,12 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -788,17 +708,12 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||
result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -810,18 +725,13 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||
result := s.db.Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -829,17 +739,12 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -847,17 +752,12 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
|
||||
@@ -869,21 +769,16 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var ipJSONStrings []string
|
||||
|
||||
// Fetch the IP addresses as JSON strings
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("ip", &ipJSONStrings)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -901,10 +796,8 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var labels []string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("dns_label", &labels)
|
||||
|
||||
@@ -912,9 +805,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
|
||||
}
|
||||
@@ -923,33 +813,23 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountNetwork AccountNetwork
|
||||
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -957,16 +837,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
|
||||
}
|
||||
return accountSettings.Settings, nil
|
||||
@@ -974,17 +849,13 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
|
||||
return status.NewGetUserFromStoreError()
|
||||
}
|
||||
user.LastLogin = lastLogin
|
||||
@@ -993,8 +864,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
definitionJSON, err := json.Marshal(checks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1003,9 +872,6 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p
|
||||
var postureCheck posture.Checks
|
||||
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1115,27 +981,20 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var setupKey SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, keyQueryCondition, key)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
}
|
||||
return &setupKey, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).
|
||||
result := s.db.Model(&SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
"used_times": gorm.Expr("used_times + 1"),
|
||||
@@ -1143,9 +1002,6 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1157,17 +1013,12 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||
result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1180,9 +1031,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
|
||||
}
|
||||
|
||||
@@ -1190,17 +1038,13 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group not found for account")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
}
|
||||
|
||||
@@ -1213,9 +1057,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
||||
}
|
||||
|
||||
@@ -1224,16 +1065,11 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
|
||||
// GetUserPeers retrieves peers for a user.
|
||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
|
||||
return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
startTime := time.Now()
|
||||
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
if err := s.db.Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1241,20 +1077,15 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
startTime := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
tx := s.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
@@ -1278,18 +1109,13 @@ func (s *SqlStore) GetDB() *gorm.DB {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountDNSSettings AccountDNSSettings
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||
}
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
@@ -1297,18 +1123,13 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select("id").First(&accountID, idQueryCondition, id)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return false, status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return false, result.Error
|
||||
}
|
||||
|
||||
@@ -1317,18 +1138,13 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var account Account
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", "", status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if errors.Is(result.Error, context.Canceled) {
|
||||
return "", "", status.NewStoreContextCanceledError(time.Since(startTime))
|
||||
}
|
||||
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
|
||||
}
|
||||
|
||||
@@ -1337,7 +1153,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
||||
|
||||
// GetGroupByID retrieves a group by ID and account ID.
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||
return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
@@ -1346,7 +1162,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||
// we may need to reconsider changing the types.
|
||||
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
|
||||
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
|
||||
if s.storeEngine == PostgresStoreEngine {
|
||||
query = query.Order("json_array_length(peers::json) DESC")
|
||||
} else {
|
||||
@@ -1365,7 +1181,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// SaveGroup saves a group to the store.
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
|
||||
}
|
||||
@@ -1374,56 +1190,56 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength,
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||
return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[*posture.Checks](s.db, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
||||
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[*route.Route](s.db, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
||||
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[*SetupKey](s.db, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
||||
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
|
||||
return getRecordByID[SetupKey](s.db, lockStrength, setupKeyID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
|
||||
return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID)
|
||||
return deleteRecordByID[SetupKey](s.db, LockingStrengthUpdate, keyID, accountID)
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
|
||||
@@ -3,7 +3,6 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
||||
conn, ok := c.conns[id]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
if conn.conn != connReference {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
@@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
if container.conn != connReference {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
|
||||
|
||||
func (c *Conn) ioErrHandling(err error) error {
|
||||
if c.isClosed() {
|
||||
return io.EOF
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
var wErr *websocket.CloseError
|
||||
@@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
|
||||
return err
|
||||
}
|
||||
if wErr.Code == websocket.StatusNormalClosure {
|
||||
return io.EOF
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -57,7 +57,7 @@ func (p *Peer) Work() {
|
||||
for {
|
||||
n, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user