Compare commits

...

9 Commits

Author SHA1 Message Date
Zoltan Papp
581b70b7aa Add random routes for the test 2024-11-11 15:01:23 +01:00
Viktor Liu
e0bed2b0fb [client] Fix race conditions (#2869)
* Fix concurrent map access in status

* Fix race when retrieving ctx state error

* Fix race when accessing service controller server instance
2024-11-11 14:55:10 +01:00
Zoltan Papp
30f025e7dd [client] fix/proxy close (#2873)
When the remote peer switches the Relay instance then must to close the proxy connection to the old instance.

It can cause issues when the remote peer switch connects to the Relay instance multiple times and then reconnects to an instance it had previously connected to.
2024-11-11 14:18:38 +01:00
Zoltan Papp
b4d7605147 [client] Remove loop after route calculation (#2856)
- ICE do not trigger disconnect callbacks if the stated did not change
- Fix route calculation callback loop
- Move route state updates into protected scope by mutex
- Do not calculate routes in case of peer.Open() and peer.Close()
2024-11-11 10:53:57 +01:00
Viktor Liu
08b6e9d647 [management] Fix api error message typo peers_group (#2862) 2024-11-08 23:28:02 +01:00
Pascal Fischer
67ce14eaea [management] Add peer lock to grpc server (#2859)
* add peer lock to grpc server

* remove sleep and put db update first

* don't export lock method
2024-11-08 18:47:22 +01:00
Pascal Fischer
669904cd06 [management] Remove context from database calls (#2863) 2024-11-08 15:49:00 +01:00
Zoltan Papp
4be826450b [client] Use offload in WireGuard bind receiver (#2815)
Improve the performance on Linux and Android in case of P2P connections
2024-11-07 17:28:38 +01:00
Maycon Santos
738387f2de Add benchmark tests to get account with claims (#2761)
* Add benchmark tests to get account with claims

* add users to account objects

* remove hardcoded env
2024-11-07 17:23:35 +01:00
24 changed files with 402 additions and 354 deletions

View File

@@ -2,6 +2,7 @@ package cmd
import (
"context"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
@@ -13,10 +14,11 @@ import (
)
type program struct {
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
serverInstanceMu sync.Mutex
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.serverInstanceMu.Unlock()
p.cancel()

View File

@@ -12,6 +12,7 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
@@ -24,8 +25,8 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
@@ -154,7 +155,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
@@ -166,16 +167,30 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
msgs := getMessages(msgsPool)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
@@ -191,11 +206,12 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
continue
}
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
@@ -273,3 +289,15 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -2,6 +2,7 @@ package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
@@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
p.Bind.RemoveEndpoint(p.wgAddr)
return p.remoteConn.Close()
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {

View File

@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel()
if err := e.remoteConn.Close(); err != nil {
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil

View File

@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}

View File

@@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.MarkSignalDisconnected(nil)
defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err)
_, err := state.Status()
c.statusRecorder.MarkSignalDisconnected(err)
}()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal

View File

@@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
@@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
@@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
}
}
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = proxy
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
return maps.Clone(s.routes)
}
// LocalPeerState contains the latest state of the local peer
@@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
@@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
d.notifyPeerListChanged()
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.AddRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.DeleteRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
@@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()
ch, found := d.changeNotify[peer]
if !found || ch == nil {
ch = make(chan struct{})
d.changeNotify[peer] = ch
if found {
return ch
}
ch = make(chan struct{})
d.changeNotify[peer] = ch
return ch
}
@@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID]
if !found {
return
}
close(ch)
delete(d.changeNotify, peerID)
}
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}

View File

@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
peerState.IP = ip
err := status.UpdatePeerState(peerState)
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error")
select {

View File

@@ -57,6 +57,9 @@ type WorkerICE struct {
localUfrag string
localPwd string
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
@@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
return
}
err := w.agent.Close()
if err != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
@@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil
w.muxAgent.Unlock()
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected)
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
@@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)

View File

@@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
}
// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}
// higher latency is worse score
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
@@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
}
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
switch {
case chosen == "":
var peers []string
@@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if !found {
c.routePeersNotifiers[r.Peer] = make(chan struct{})
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
if found {
continue
}
closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
}
}
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
@@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
@@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
@@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
c.addStateRoute()
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update

View File

@@ -1,6 +1,7 @@
package routemanager
import (
"fmt"
"net/netip"
"testing"
"time"
@@ -287,6 +288,26 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
},
}
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: "peer1",
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: "peer2",
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{

View File

@@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`

2
go.mod
View File

@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

4
go.sum
View File

@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=

View File

@@ -982,6 +982,110 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
claims := jwtclaims.AuthorizationClaims{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
}
publicClaims := jwtclaims.AuthorizationClaims{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: PublicCategory,
}
am, err := createManager(b)
if err != nil {
b.Fatal(err)
return
}
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
users := genUsers("priv", 100)
acc, err := am.Store.GetAccount(context.Background(), id)
if err != nil {
b.Fatal(err)
}
acc.Users = users
err = am.Store.SaveAccount(context.Background(), acc)
if err != nil {
b.Fatal(err)
}
userP := genUsers("pub", 100)
pacc, err := am.Store.GetAccount(context.Background(), pid)
if err != nil {
b.Fatal(err)
}
pacc.Users = userP
err = am.Store.SaveAccount(context.Background(), pacc)
if err != nil {
b.Fatal(err)
}
b.Run("public without account ID", func(b *testing.B) {
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private without account ID", func(b *testing.B) {
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private with account ID", func(b *testing.B) {
claims.AccountId = id
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
}
func genUsers(p string, n int) map[string]*User {
users := map[string]*User{}
now := time.Now()
for i := 0; i < n; i++ {
users[fmt.Sprintf("%s-%d", p, i)] = &User{
Id: fmt.Sprintf("%s-%d", p, i),
Role: UserRoleAdmin,
LastLogin: now,
CreatedAt: now,
Issued: "api",
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
}
}
return users
}
func TestAccountManager_AddPeer(t *testing.T) {
manager, err := createManager(t)
if err != nil {

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -38,6 +39,7 @@ type GRPCServer struct {
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
}
// NewServer creates a new Management server
@@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
@@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
unlock()
unlock = nil
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
@@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
@@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return claims.UserId, nil
}
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
start = time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok {

View File

@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
}
if req.Peer == nil && req.PeerGroups == nil {
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
}
if req.Peer != nil && req.PeerGroups != nil {

View File

@@ -300,13 +300,11 @@ func (s *SqlStore) GetInstallationID() string {
}
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
startTime := time.Now()
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := s.db.Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving
var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
@@ -327,9 +325,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
})
if err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return err
}
@@ -337,8 +332,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
}
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
startTime := time.Now()
accountCopy := Account{
Domain: domain,
DomainCategory: category,
@@ -346,14 +339,11 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
}
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
result := s.db.WithContext(ctx).Model(&Account{}).
result := s.db.Model(&Account{}).
Select(fieldsToUpdate).
Where(idQueryCondition, accountID).
Updates(&accountCopy)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error
}
@@ -365,8 +355,6 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
startTime := time.Now()
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
@@ -379,9 +367,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error
}
@@ -393,8 +378,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
}
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
startTime := time.Now()
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
@@ -406,9 +389,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
Updates(peerCopy)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error
}
@@ -422,8 +402,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
startTime := time.Now()
usersToSave := make([]User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
@@ -437,9 +415,6 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
if err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
}
@@ -448,13 +423,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
// SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
startTime := time.Now()
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
}
return nil
@@ -462,17 +432,12 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
startTime := time.Now()
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
@@ -499,10 +464,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
startTime := time.Now()
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
@@ -510,9 +473,6 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -521,17 +481,12 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
startTime := time.Now()
var key SetupKey
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey)
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
}
@@ -543,17 +498,12 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
}
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
startTime := time.Now()
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -562,17 +512,12 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
startTime := time.Now()
var token PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -596,18 +541,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
startTime := time.Now()
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetUserFromStoreError()
}
@@ -615,17 +555,12 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
startTime := time.Now()
var users []*User
result := s.db.Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting users from store")
}
@@ -634,17 +569,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
startTime := time.Now()
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store")
}
@@ -744,17 +674,12 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
startTime := time.Now()
var user User
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -766,17 +691,12 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
}
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
startTime := time.Now()
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -788,17 +708,12 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
}
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
startTime := time.Now()
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -810,18 +725,13 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
}
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
startTime := time.Now()
var peer nbpeer.Peer
var accountID string
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
result := s.db.Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -829,17 +739,12 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
startTime := time.Now()
var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -847,17 +752,12 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
startTime := time.Now()
var accountID string
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewSetupKeyNotFoundError(result.Error)
}
@@ -869,21 +769,16 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
startTime := time.Now()
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
}
@@ -901,10 +796,8 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
startTime := time.Now()
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
@@ -912,9 +805,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
}
@@ -923,33 +813,23 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
startTime := time.Now()
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
startTime := time.Now()
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
}
@@ -957,16 +837,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
startTime := time.Now()
var accountSettings AccountSettings
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
}
return accountSettings.Settings, nil
@@ -974,17 +849,13 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
startTime := time.Now()
var user User
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
}
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
@@ -993,8 +864,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
startTime := time.Now()
definitionJSON, err := json.Marshal(checks)
if err != nil {
return nil, err
@@ -1003,9 +872,6 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p
var postureCheck posture.Checks
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, err
}
@@ -1115,27 +981,20 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
startTime := time.Now()
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, key)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
startTime := time.Now()
result := s.db.WithContext(ctx).Model(&SetupKey{}).
result := s.db.Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
@@ -1143,9 +1002,6 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
})
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
@@ -1157,17 +1013,12 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
startTime := time.Now()
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
}
@@ -1180,9 +1031,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
}
@@ -1190,17 +1038,13 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
startTime := time.Now()
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
}
@@ -1213,9 +1057,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue updating group: %s", err)
}
@@ -1224,16 +1065,11 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
startTime := time.Now()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1241,20 +1077,15 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
startTime := time.Now()
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
tx := s.db.Begin()
if tx.Error != nil {
return tx.Error
}
@@ -1278,18 +1109,13 @@ func (s *SqlStore) GetDB() *gorm.DB {
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
startTime := time.Now()
var accountDNSSettings AccountDNSSettings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
}
return &accountDNSSettings.DNSSettings, nil
@@ -1297,18 +1123,13 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
startTime := time.Now()
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("id").First(&accountID, idQueryCondition, id)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return false, nil
}
if errors.Is(result.Error, context.Canceled) {
return false, status.NewStoreContextCanceledError(time.Since(startTime))
}
return false, result.Error
}
@@ -1317,18 +1138,13 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
startTime := time.Now()
var account Account
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found")
}
if errors.Is(result.Error, context.Canceled) {
return "", "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
}
@@ -1337,7 +1153,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID)
}
// GetGroupByName retrieves a group by name and account ID.
@@ -1346,7 +1162,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// TODO: This fix is accepted for now, but if we need to handle this more frequently
// we may need to reconsider changing the types.
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
if s.storeEngine == PostgresStoreEngine {
query = query.Order("json_array_length(peers::json) DESC")
} else {
@@ -1365,7 +1181,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
@@ -1374,56 +1190,56 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength,
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
return getRecords[*posture.Checks](s.db, lockStrength, accountID)
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
return getRecords[*route.Route](s.db, lockStrength, accountID)
}
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID)
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
return getRecords[*SetupKey](s.db, lockStrength, accountID)
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
return getRecordByID[SetupKey](s.db, lockStrength, setupKeyID, accountID)
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
}
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID)
return deleteRecordByID[SetupKey](s.db, LockingStrengthUpdate, keyID, accountID)
}
// getRecords retrieves records from the database based on the account ID.

View File

@@ -3,7 +3,6 @@ package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
@@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
conn, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
return 0, net.ErrClosed
}
if conn.conn != connReference {
return 0, io.EOF
return 0, net.ErrClosed
}
// todo: use buffer pool instead of create new transport msg.
@@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
return net.ErrClosed
}
if container.conn != connReference {

View File

@@ -1,7 +1,6 @@
package client
import (
"io"
"net"
"time"
)
@@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
return 0, net.ErrClosed
}
n = copy(b, msg.Payload)

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
@@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() {
return io.EOF
return net.ErrClosed
}
var wErr *websocket.CloseError
@@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
return err
}
if wErr.Code == websocket.StatusNormalClosure {
return io.EOF
return net.ErrClosed
}
return err
}

View File

@@ -2,7 +2,7 @@ package server
import (
"context"
"io"
"errors"
"net"
"sync"
"time"
@@ -57,7 +57,7 @@ func (p *Peer) Work() {
for {
n, err := p.conn.Read(buf)
if err != nil {
if err != io.EOF {
if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err)
}
return