Compare commits

...

15 Commits

Author SHA1 Message Date
Viktor Liu
417fa6e833 Change default interface name from wt0 to nb0 2025-07-21 17:58:56 +02:00
Bethuel Mmbaga
a7af15c4fc [management] Fix group resource count mismatch in policy (#4182) 2025-07-21 15:26:06 +03:00
Viktor Liu
d6ed9c037e [client] Fix bind exclusion routes (#4154) 2025-07-21 12:13:21 +02:00
Ali Amer
40fdeda838 [client] add new filter-by-connection-type flag (#4010)
introduces a new flag --filter-by-connection-type to the status command.
It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views.

Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters).
2025-07-21 11:55:17 +02:00
Zoltan Papp
f6e9d755e4 [client, relay] The openConn function no longer blocks the relayAddress function call (#4180)
The openConn function no longer blocks the relayAddress function call in manager layer
2025-07-21 09:46:53 +02:00
Maycon Santos
08fd460867 [management] Add validate flow response (#4172)
This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface.

- Adds a new ValidateFlowResponse method to the IntegratedValidator interface
- Integrates the validator into the management server to validate PKCE authorization flows
- Updates dependency version for management-integrations
2025-07-18 12:18:52 +02:00
Pascal Fischer
4f74509d55 [management] fix index creation if exist on mysql (#4150) 2025-07-16 15:07:31 +02:00
Maycon Santos
58185ced16 [misc] add forum post and update sign pipeline (#4155)
use old git-town version
2025-07-16 14:10:28 +02:00
Pedro Maia Costa
e67f44f47c [client] fix test (#4156) 2025-07-16 12:09:38 +02:00
Zoltan Papp
b524f486e2 [client] Fix/nil relayed address (#4153)
Fix nil pointer in Relay conn address

Meanwhile, we create a relayed net.Conn struct instance, it is possible to set the relayedURL to nil.

panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer

Fix relayed URL variable protection
Protect the channel closing
2025-07-16 00:00:18 +02:00
Zoltan Papp
0dab03252c [client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes.
- The server manages and maintains these subscriptions.
- Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
2025-07-15 10:43:42 +02:00
iisteev
e49bcc343d [client] Avoid parsing NB_NETSTACK_SKIP_PROXY if empty (#4145)
Signed-off-by: iisteev <isteevan.shetoo@is-info.fr>
2025-07-13 15:42:48 +02:00
Zoltan Papp
3e6eede152 [client] Fix elapsed time calculation when machine is in sleep mode (#4140) 2025-07-12 11:10:45 +02:00
Vlad
a76c8eafb4 [management] sync calls to UpdateAccountPeers from BufferUpdateAccountPeers (#4137)
---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-07-11 12:37:14 +03:00
Pedro Maia Costa
2b9f331980 always suffix ephemeral peer name (#4138) 2025-07-11 10:29:10 +01:00
113 changed files with 2271 additions and 851 deletions

View File

@@ -16,6 +16,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: git-town/action@v1 - uses: git-town/action@v1.2.1
with: with:
skip-single-stacks: true skip-single-stacks: true

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.20" SIGN_PIPE_VER: "v0.0.21"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -231,3 +231,17 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -307,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -26,6 +26,7 @@ var (
statusFilter string statusFilter string
ipsFilterMap map[string]struct{} ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
) )
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
@@ -45,6 +46,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -89,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
@@ -156,6 +158,15 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil return nil
} }

View File

@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -34,14 +34,14 @@ func NewActivityRecorder() *ActivityRecorder {
} }
// GetLastActivities returns a snapshot of peer last activity // GetLastActivities returns a snapshot of peer last activity
func (r *ActivityRecorder) GetLastActivities() map[string]time.Time { func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
activities := make(map[string]time.Time, len(r.peers)) activities := make(map[string]monotime.Time, len(r.peers))
for key, record := range r.peers { for key, record := range r.peers {
unixNano := record.LastActivity.Load() monoTime := record.LastActivity.Load()
activities[key] = time.Unix(0, unixNano) activities[key] = monotime.Time(monoTime)
} }
return activities return activities
} }
@@ -51,18 +51,20 @@ func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPor
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if pr, exists := r.peers[publicKey]; exists { var record *PeerRecord
delete(r.addrToPeer, pr.Address) record, exists := r.peers[publicKey]
pr.Address = address if exists {
delete(r.addrToPeer, record.Address)
record.Address = address
} else { } else {
record := &PeerRecord{ record = &PeerRecord{
Address: address, Address: address,
} }
record.LastActivity.Store(monotime.Now()) record.LastActivity.Store(int64(monotime.Now()))
r.peers[publicKey] = record r.peers[publicKey] = record
} }
r.addrToPeer[address] = r.peers[publicKey] r.addrToPeer[address] = record
} }
func (r *ActivityRecorder) Remove(publicKey string) { func (r *ActivityRecorder) Remove(publicKey string) {
@@ -84,7 +86,7 @@ func (r *ActivityRecorder) record(address netip.AddrPort) {
return return
} }
now := monotime.Now() now := int64(monotime.Now())
last := record.LastActivity.Load() last := record.LastActivity.Load()
if now-last < saveFrequency { if now-last < saveFrequency {
return return

View File

@@ -4,6 +4,8 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/monotime"
) )
func TestActivityRecorder_GetLastActivities(t *testing.T) { func TestActivityRecorder_GetLastActivities(t *testing.T) {
@@ -17,11 +19,7 @@ func TestActivityRecorder_GetLastActivities(t *testing.T) {
t.Fatalf("Expected activity for peer %s, but got none", peer) t.Fatalf("Expected activity for peer %s, but got none", peer)
} }
if p.IsZero() { if monotime.Since(p) > 5*time.Second {
t.Fatalf("Expected activity for peer %s, but got zero", peer)
}
if p.Before(time.Now().Add(-2 * time.Minute)) {
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p) t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
} }
} }

View File

@@ -0,0 +1,15 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -16,6 +16,7 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: nbnet.WrapUDPConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address, WGAddress: s.address,

View File

@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return return
} }
m.addressMapMu.Lock() var allAddresses []string
defer m.addressMapMu.Unlock()
for _, c := range removedConns { for _, c := range removedConns {
addresses := c.getAddresses() addresses := c.getAddresses()
for _, addr := range addresses { allAddresses = append(allAddresses, addresses...)
delete(m.addressMap, addr) }
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr)
}
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
} }
} }
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
} }
m.addressMapMu.Lock() m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr] existing, ok := m.addressMap[addr]
if !ok { if !ok {
existing = []*udpMuxedConn{} existing = []*udpMuxedConn{}
} }
existing = append(existing, conn) existing = append(existing, conn)
m.addressMap[addr] = existing m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections. // We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock() m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.Unlock() m.addressMapMu.RUnlock()
var isIPv6 bool var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -0,0 +1,21 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
wrapped, ok := m.params.UDPConn.(*UDPConn)
if !ok {
return
}
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
if !ok {
return
}
nbnetConn.RemoveAddress(addr)
}

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages // wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker) // before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{ m.params.UDPConn = &UDPConn{
PacketConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
// embed UDPMux
udpMuxParams := UDPMuxParams{ udpMuxParams := UDPMuxParams{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
} }
} }
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct { type UDPConn struct {
net.PacketConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
logger logging.LeveledLogger logger logging.LeveledLogger
@@ -125,7 +124,12 @@ type udpConn struct {
address wgaddr.Address address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { // GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil { if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr) return u.handleUncachedAddress(b, addr)
} }
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted { if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil { if err := u.performFilterCheck(addr); err != nil {
return 0, err return 0, err
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) performFilterCheck(addr net.Addr) error { func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr) host, err := getHostFromAddr(addr)
if err != nil { if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err) log.Errorf("Failed to get host from address %s: %v", addr, err)

View File

@@ -11,6 +11,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
) )
var zeroKey wgtypes.Key var zeroKey wgtypes.Key
@@ -277,6 +279,6 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
return stats, nil return stats, nil
} }
func (c *KernelConfigurer) LastActivities() map[string]time.Time { func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil return nil
} }

View File

@@ -3,4 +3,4 @@
package configurer package configurer
// WgInterfaceDefault is a default interface name of Netbird // WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "wt0" const WgInterfaceDefault = "nb0"

View File

@@ -17,6 +17,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -223,7 +224,7 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
return parseStatus(c.deviceName, ipcStr) return parseStatus(c.deviceName, ipcStr)
} }
func (c *WGUSPConfigurer) LastActivities() map[string]time.Time { func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
return c.activityRecorder.GetLastActivities() return c.activityRecorder.GetLastActivities()
} }

View File

@@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/monotime"
) )
type WGConfigurer interface { type WGConfigurer interface {
@@ -19,5 +20,5 @@ type WGConfigurer interface {
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]time.Time LastActivities() map[string]monotime.Time
} }

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
) )
const ( const (
@@ -237,7 +238,7 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats() return w.configurer.GetStats()
} }
func (w *WGIface) LastActivities() map[string]time.Time { func (w *WGIface) LastActivities() map[string]monotime.Time {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -41,9 +41,12 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
} }
t.tundev = nsTunDev t.tundev = nsTunDev
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) var skipProxy bool
if err != nil { if val := os.Getenv(EnvSkipProxy); val != "" {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err) skipProxy, err = strconv.ParseBool(val)
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
} }
if skipProxy { if skipProxy {
return nsTunDev, tunNet, nil return nsTunDev, tunNet, nil

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type ProxyBind struct { type ProxyBind struct {
@@ -28,6 +29,17 @@ type ProxyBind struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
} }
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
@@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
} }
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }

View File

@@ -11,6 +11,8 @@ import (
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -26,6 +28,15 @@ type ProxyWrapper struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() { func (p *ProxyWrapper) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel() e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
@@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil { if ctx.Err() != nil {
return 0, ctx.Err() return 0, ctx.Err()
} }
p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
} }

View File

@@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort)
} }
return &ebpf.ProxyWrapper{ return ebpf.NewProxyWrapper(w.ebpfProxy)
WgeBPFProxy: w.ebpfProxy,
}
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{ return proxyBind.NewProxyBind(w.bind)
Bind: w.bind,
}
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -0,0 +1,19 @@
package listener
type CloseListener struct {
listener func()
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.listener = listener
}
func (c *CloseListener) Notify() {
if c.listener != nil {
c.listener()
}
}

View File

@@ -12,4 +12,5 @@ type Proxy interface {
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func())
} }

View File

@@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err) t.Errorf("failed to free ebpf proxy: %s", err)
} }
}() }()
proxyWrapper := &ebpf.ProxyWrapper{ proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct { tests = append(tests, struct {
name string name string

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
@@ -28,6 +29,8 @@ type WGUDPProxy struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
} }
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
} }
return p return p
} }
@@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr return endpointUdpAddr
} }
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused // Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() { func (p *WGUDPProxy) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
@@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
return return
} }

View File

@@ -39,7 +39,7 @@ const (
) )
var defaultInterfaceBlacklist = []string{ var defaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo", "Tailscale", "tailscale", "docker", "veth", "br-", "lo",
} }

View File

@@ -226,7 +226,6 @@ func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
} }
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found { if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(ctx); err != nil { if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err) conn.Log.Errorf("failed to open connection: %v", err)
} }

View File

@@ -61,7 +61,6 @@ import (
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -138,9 +137,6 @@ type Engine struct {
connMgr *ConnMgr connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager
@@ -409,12 +405,8 @@ func (e *Engine) Start() error {
DisableClientRoutes: e.config.DisableClientRoutes, DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes, DisableServerRoutes: e.config.DisableServerRoutes,
}) })
beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err := e.routeManager.Init(); err != nil {
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
} }
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
@@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
return fmt.Errorf("peer already exists: %s", peerKey) return fmt.Errorf("peer already exists: %s", peerKey)
} }
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil return nil
} }

View File

@@ -52,6 +52,7 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@@ -96,7 +97,7 @@ type MockWGIface struct {
GetInterfaceGUIDStringFunc func() (string, error) GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]time.Time LastActivitiesFunc func() map[string]monotime.Time
} }
func (m *MockWGIface) FullStats() (*configurer.Stats, error) { func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
@@ -187,7 +188,7 @@ func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc() return m.GetNetFunc()
} }
func (m *MockWGIface) LastActivities() map[string]time.Time { func (m *MockWGIface) LastActivities() map[string]monotime.Time {
if m.LastActivitiesFunc != nil { if m.LastActivitiesFunc != nil {
return m.LastActivitiesFunc() return m.LastActivitiesFunc()
} }
@@ -399,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
StatusRecorder: engine.statusRecorder, StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr, RelayManager: relayMgr,
}) })
_, _, err = engine.routeManager.Init() err = engine.routeManager.Init()
require.NoError(t, err) require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -1392,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i) ifaceName = fmt.Sprintf("utun1%d", i)
} else { } else {
ifaceName = fmt.Sprintf("wt%d", i) ifaceName = fmt.Sprintf("nb%d", i)
} }
wgPort := 33100 + i wgPort := 33100 + i
@@ -1480,6 +1481,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
@@ -1489,7 +1494,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
) )
type wgIfaceBase interface { type wgIfaceBase interface {
@@ -38,5 +39,5 @@ type wgIfaceBase interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net GetNet() *netstack.Net
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]time.Time LastActivities() map[string]monotime.Time
} }

View File

@@ -48,7 +48,7 @@ func (d *Listener) ReadPackets() {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil { if err != nil {
if d.isClosed.Load() { if d.isClosed.Load() {
d.peerCfg.Log.Debugf("exit from activity listener") d.peerCfg.Log.Infof("exit from activity listener")
} else { } else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err) d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
} }
@@ -59,9 +59,11 @@ func (d *Listener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr) d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue continue
} }
d.peerCfg.Log.Infof("activity detected")
break break
} }
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil { if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
} }
@@ -71,7 +73,7 @@ func (d *Listener) ReadPackets() {
} }
func (d *Listener) Close() { func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String()) d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true) d.isClosed.Store(true)
if err := d.conn.Close(); err != nil { if err := d.conn.Close(); err != nil {
@@ -81,7 +83,6 @@ func (d *Listener) Close() {
} }
func (d *Listener) removeEndpoint() error { func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey) return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
} }

View File

@@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
) )
const ( const (
@@ -18,7 +19,7 @@ const (
) )
type WgInterface interface { type WgInterface interface {
LastActivities() map[string]time.Time LastActivities() map[string]monotime.Time
} }
type Manager struct { type Manager struct {
@@ -124,6 +125,7 @@ func (m *Manager) checkStats() (map[string]struct{}, error) {
idlePeers := make(map[string]struct{}) idlePeers := make(map[string]struct{})
checkTime := time.Now()
for peerID, peerCfg := range m.interestedPeers { for peerID, peerCfg := range m.interestedPeers {
lastActive, ok := lastActivities[peerID] lastActive, ok := lastActivities[peerID]
if !ok { if !ok {
@@ -132,8 +134,9 @@ func (m *Manager) checkStats() (map[string]struct{}, error) {
continue continue
} }
if time.Since(lastActive) > m.inactivityThreshold { since := monotime.Since(lastActive)
peerCfg.Log.Infof("peer is inactive since: %v", lastActive) if since > m.inactivityThreshold {
peerCfg.Log.Infof("peer is inactive since time: %s", checkTime.Add(-since).String())
idlePeers[peerID] = struct{}{} idlePeers[peerID] = struct{}{}
} }
} }

View File

@@ -9,13 +9,14 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
) )
type mockWgInterface struct { type mockWgInterface struct {
lastActivities map[string]time.Time lastActivities map[string]monotime.Time
} }
func (m *mockWgInterface) LastActivities() map[string]time.Time { func (m *mockWgInterface) LastActivities() map[string]monotime.Time {
return m.lastActivities return m.lastActivities
} }
@@ -23,8 +24,8 @@ func TestPeerTriggersInactivity(t *testing.T) {
peerID := "peer1" peerID := "peer1"
wgMock := &mockWgInterface{ wgMock := &mockWgInterface{
lastActivities: map[string]time.Time{ lastActivities: map[string]monotime.Time{
peerID: time.Now().Add(-20 * time.Minute), peerID: monotime.Time(int64(monotime.Now()) - int64(20*time.Minute)),
}, },
} }
@@ -64,8 +65,8 @@ func TestPeerTriggersActivity(t *testing.T) {
peerID := "peer1" peerID := "peer1"
wgMock := &mockWgInterface{ wgMock := &mockWgInterface{
lastActivities: map[string]time.Time{ lastActivities: map[string]monotime.Time{
peerID: time.Now().Add(-5 * time.Minute), peerID: monotime.Time(int64(monotime.Now()) - int64(5*time.Minute)),
}, },
} }

View File

@@ -258,12 +258,13 @@ func (m *Manager) ActivatePeer(peerID string) (found bool) {
return false return false
} }
cfg.Log.Infof("activate peer from inactive state by remote signal message")
if !m.activateSinglePeer(cfg, mp) { if !m.activateSinglePeer(cfg, mp) {
return false return false
} }
m.activateHAGroupPeers(cfg) m.activateHAGroupPeers(cfg)
return true return true
} }
@@ -571,12 +572,12 @@ func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
// this is blocking operation, potentially can be optimized // this is blocking operation, potentially can be optimized
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey) m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity mp.expectedWatcher = watcherActivity
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey) m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil { if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err) mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
continue continue

View File

@@ -6,11 +6,13 @@ import (
"time" "time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
) )
type WGIface interface { type WGIface interface {
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool IsUserspaceBind() bool
LastActivities() map[string]time.Time LastActivities() map[string]monotime.Time
} }

View File

@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
} }
func (m *mockIFaceMapper) Name() string { func (m *mockIFaceMapper) Name() string {
return "wt0" return "nb0"
} }
func (m *mockIFaceMapper) Address() wgaddr.Address { func (m *mockIFaceMapper) Address() wgaddr.Address {

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -106,10 +105,6 @@ type Conn struct {
workerRelay *WorkerRelay workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte rosenpassRemoteKey []byte
@@ -167,7 +162,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey) conn.onDisconnected(conn.config.WgConfig.RemoteKey)
} }
@@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
} }
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler conn.onConnected = handler
@@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp ep = directEp
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here // todo consider to run conn.wgWatcherWg.Wait() here
@@ -489,6 +471,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
@@ -501,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
@@ -705,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDRelay = ""
}
if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDICE = ""
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{

View File

@@ -19,6 +19,7 @@ type RelayConnInfo struct {
} }
type WorkerRelay struct { type WorkerRelay struct {
peerCtx context.Context
log *log.Entry log *log.Entry
isController bool isController bool
config ConnConfig config ConnConfig
@@ -33,8 +34,9 @@ type WorkerRelay struct {
wgWatcher *WGWatcher wgWatcher *WGWatcher
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
peerCtx: ctx,
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
@@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
if err != nil { if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection") w.log.Debugf("handled offer by reusing existing relay connection")

View File

@@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
} }
params := common.HandlerParams{ params := common.HandlerParams{
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
} }
// create new clientNetwork // create new clientNetwork
client := &Watcher{ client := &Watcher{

View File

@@ -44,7 +44,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() error
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
@@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() error {
m.routeSelector = m.initSelector() m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
return nil, nil, nil return nil
} }
if err := m.sysOps.CleanupRouting(nil); err != nil { if err := m.sysOps.CleanupRouting(nil); err != nil {
@@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err != nil { return fmt.Errorf("setup routing: %w", err)
return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
log.Info("Routing setup complete") log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil return nil
} }
func (m *DefaultManager) initSelector() *routeselector.RouteSelector { func (m *DefaultManager) initSelector() *routeselector.RouteSelector {

View File

@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
StatusRecorder: statusRecorder, StatusRecorder: statusRecorder,
}) })
_, _, err = routeManager.Init() err = routeManager.Init()
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil) defer routeManager.Stop(nil)

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
@@ -23,8 +22,8 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() error {
return nil, nil, nil return nil
} }
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface // InitialRouteRange mock implementation of InitialRouteRange from Manager interface

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -56,6 +57,10 @@ type SysOps struct {
// seq is an atomic counter for generating unique sequence numbers for route messages // seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems //nolint:unused // only used on BSD systems
seq atomic.Uint32 seq atomic.Uint32
localSubnetsCache []*net.IPNet
localSubnetsCacheMu sync.RWMutex
localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {

View File

@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
@@ -24,6 +25,8 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const localSubnetsCacheTTL = 15 * time.Minute
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
@@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate") var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager) if err := r.setupHooks(initAddresses, stateManager); err != nil {
return fmt.Errorf("setup hooks: %w", err)
}
return nil
} }
// updateState updates state on every change so it will be persisted regularly // updateState updates state on every change so it will be persisted regularly
@@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, fmt.Errorf("get next hop: %w", err) return Nexthop{}, fmt.Errorf("get next hop: %w", err)
} }
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
exitNextHop := Nexthop{ exitNextHop := nexthop
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr := vpnIntf.Address().IP vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop exitNextHop = initialNextHop
} }
@@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
} }
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
r.localSubnetsCacheMu.RLock()
cacheAge := time.Since(r.localSubnetsCacheTime)
subnets := r.localSubnetsCache
r.localSubnetsCacheMu.RUnlock()
if cacheAge > localSubnetsCacheTTL || subnets == nil {
r.localSubnetsCacheMu.Lock()
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
r.refreshLocalSubnetsCache()
}
subnets = r.localSubnetsCache
r.localSubnetsCacheMu.Unlock()
}
for _, subnet := range subnets {
if subnet.Contains(prefix.Addr().AsSlice()) {
return true, subnet
}
}
return false, nil
}
func (r *SysOps) refreshLocalSubnetsCache() {
localInterfaces, err := net.Interfaces() localInterfaces, err := net.Interfaces()
if err != nil { if err != nil {
log.Errorf("Failed to get local interfaces: %v", err) log.Errorf("Failed to get local interfaces: %v", err)
return false, nil return
} }
var newSubnets []*net.IPNet
for _, intf := range localInterfaces { for _, intf := range localInterfaces {
addrs, err := intf.Addrs() addrs, err := intf.Addrs()
if err != nil { if err != nil {
@@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
log.Errorf("Failed to convert address to IPNet: %v", addr) log.Errorf("Failed to convert address to IPNet: %v", addr)
continue continue
} }
newSubnets = append(newSubnets, ipnet)
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
} }
} }
return false, nil r.localSubnetsCache = newSubnets
r.localSubnetsCacheTime = time.Now()
} }
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
@@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil return nil
} }
var merr *multierror.Error
for _, ip := range initAddresses { for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil { if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err) merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
} }
} }
@@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return ctx.Err() return ctx.Err()
} }
var result *multierror.Error var merr *multierror.Error
for _, ip := range resolvedIPs { for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP)) merr = multierror.Append(merr, beforeHook(connID, ip.IP))
} }
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(merr)
}) })
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
@@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return afterHook(connID) return afterHook(connID)
}) })
return beforeHook, afterHook, nil nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
})
return nberrors.FormatErrorOrNil(merr)
} }
func GetNextHop(ip netip.Addr) (Nexthop, error) { func GetNextHop(ip netip.Addr) (Nexthop, error) {

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
}) })
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))

View File

@@ -10,14 +10,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{}) r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return fmt.Errorf("%s: %w", rule.description, err)
} }
} }
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
} }
originalSysctl = originalValues originalSysctl = originalValues
return nil, nil, nil return nil
} }
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.

View File

@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
IP: wgNetwork.Addr(), IP: wgNetwork.Addr(),
Network: wgNetwork, Network: wgNetwork,
}, },
name: "wt0", name: "nb0",
} }
sysOps := &SysOps{ sysOps := &SysOps{

View File

@@ -18,10 +18,9 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const InfiniteLifetime = 0xffffffff const InfiniteLifetime = 0xffffffff
@@ -137,7 +136,7 @@ const (
RouteDeleted RouteDeleted
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -1330,6 +1330,13 @@ func (x *PeerState) GetRelayAddress() string {
return "" return ""
} }
func (x *PeerState) GetConnectionType() string {
if x.Relayed {
return "Relayed"
}
return "P2P"
}
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
type LocalPeerState struct { type LocalPeerState struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`

View File

@@ -212,7 +212,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -100,7 +100,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
} }
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview { func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview {
pbFullStatus := resp.GetFullStatus() pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState() managementState := pbFullStatus.GetManagementState()
@@ -118,7 +118,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
} }
relayOverview := mapRelays(pbFullStatus.GetRelays()) relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{ overview := OutputOverview{
Peers: peersOverview, Peers: peersOverview,
@@ -193,6 +193,7 @@ func mapPeers(
prefixNamesFilter []string, prefixNamesFilter []string,
prefixNamesFilterMap map[string]struct{}, prefixNamesFilterMap map[string]struct{},
ipsFilter map[string]struct{}, ipsFilter map[string]struct{},
connectionTypeFilter string,
) PeersStateOutput { ) PeersStateOutput {
var peersStateDetail []PeerStateDetailOutput var peersStateDetail []PeerStateDetailOutput
peersConnected := 0 peersConnected := 0
@@ -208,7 +209,7 @@ func mapPeers(
transferSent := int64(0) transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) { if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) {
continue continue
} }
if isPeerConnected { if isPeerConnected {
@@ -218,10 +219,7 @@ func mapPeers(
remoteICE = pbPeerState.GetRemoteIceCandidateType() remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P" connType = pbPeerState.GetConnectionType()
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress() relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx() transferReceived = pbPeerState.GetBytesRx()
@@ -542,10 +540,11 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
return peersString return peersString
} }
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool { func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := true nameEval := true
connectionTypeEval := false
if statusFilter != "" { if statusFilter != "" {
if !strings.EqualFold(peerStatus, statusFilter) { if !strings.EqualFold(peerStatus, statusFilter) {
@@ -570,8 +569,11 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi
} else { } else {
nameEval = false nameEval = false
} }
if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) {
connectionTypeEval = true
}
return statusEval || ipEval || nameEval return statusEval || ipEval || nameEval || connectionTypeEval
} }
func toIEC(b int64) string { func toIEC(b int64) string {

View File

@@ -234,7 +234,7 @@ var overview = OutputOverview{
} }
func TestConversionFromFullStatusToOutputOverview(t *testing.T) { func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil) convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "")
assert.Equal(t, overview, convertedResult) assert.Equal(t, overview, convertedResult)
} }

View File

@@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "")
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "")
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "")
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=

View File

@@ -112,7 +112,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -292,7 +292,7 @@ var (
ephemeralManager.LoadInitialPeers(ctx) ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager) srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }

View File

@@ -2887,7 +2887,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createDNSStore(t *testing.T) (store.Store, error) { func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
@@ -40,13 +41,14 @@ type GRPCServer struct {
settingsManager settings.Manager settingsManager settings.Manager
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *types.Config config *types.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
peerLocks sync.Map peerLocks sync.Map
authManager auth.Manager authManager auth.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
} }
// NewServer creates a new Management server // NewServer creates a new Management server
@@ -60,6 +62,7 @@ func NewServer(
appMetrics telemetry.AppMetrics, appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager, ephemeralManager *EphemeralManager,
authManager auth.Manager, authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
) (*GRPCServer, error) { ) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
@@ -79,14 +82,15 @@ func NewServer(
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
accountManager: accountManager, accountManager: accountManager,
settingsManager: settingsManager, settingsManager: settingsManager,
config: config, config: config,
secretsManager: secretsManager, secretsManager: secretsManager,
authManager: authManager, authManager: authManager,
appMetrics: appMetrics, appMetrics: appMetrics,
ephemeralManager: ephemeralManager, ephemeralManager: ephemeralManager,
integratedPeerValidator: integratedPeerValidator,
}, nil }, nil
} }
@@ -850,7 +854,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
} }
flowInfoResp := &proto.PKCEAuthorizationFlow{ initInfoFlow := &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{ ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
@@ -865,6 +869,8 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
}, },
} }
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")

View File

@@ -424,9 +424,10 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
} }
if group, ok := groupsMap[gid]; ok { if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{ minimum := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
PeersCount: len(group.Peers), PeersCount: len(group.Peers),
ResourcesCount: len(group.Resources),
} }
destinations = append(destinations, minimum) destinations = append(destinations, minimum)
cache[gid] = minimum cache[gid] = minimum

View File

@@ -1,4 +1,5 @@
package testing_tools package testing_tools
import ( import (
"bytes" "bytes"
"context" "context"
@@ -132,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
} }
geoMock := &geolocation.Mock{} geoMock := &geolocation.Mock{}
validatorMock := server.MocIntegratedValidator{} validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store) proxyController := integrations.NewController(store)
userManager := users.NewManager(store) userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -101,22 +102,23 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
} }
type MocIntegratedValidator struct { type MockIntegratedValidator struct {
integrated_validator.IntegratedValidator
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil { if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
} }
return update, false, nil return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{}) validatedPeers := make(map[string]struct{})
for _, peer := range peers { for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{} validatedPeers[peer.ID] = struct{}{}
@@ -124,22 +126,22 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*ty
return validatedPeers, nil return validatedPeers, nil
} }
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
return peer return peer
} }
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) { func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
return false, false, nil return false, false, nil
} }
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil return nil
} }
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { func (MockIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
// just a dummy // just a dummy
} }
func (MocIntegratedValidator) Stop(_ context.Context) { func (MockIntegratedValidator) Stop(_ context.Context) {
// just a dummy // just a dummy
} }

View File

@@ -3,6 +3,7 @@ package integrated_validator
import ( import (
"context" "context"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@@ -17,4 +18,5 @@ type IntegratedValidator interface {
PeerDeleted(ctx context.Context, accountID, peerID string) error PeerDeleted(ctx context.Context, accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string)) SetPeerInvalidationListener(fn func(accountID string))
Stop(ctx context.Context) Stop(ctx context.Context)
ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow
} }

View File

@@ -448,7 +448,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
cleanup() cleanup()
@@ -458,7 +458,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }

View File

@@ -206,7 +206,7 @@ func startServer(
eventStore, eventStore,
nil, nil,
false, false,
server.MocIntegratedValidator{}, server.MockIntegratedValidator{},
metrics, metrics,
port_forwarding.NewControllerMock(), port_forwarding.NewControllerMock(),
settingsMockManager, settingsMockManager,
@@ -227,6 +227,7 @@ func startServer(
nil, nil,
nil, nil,
nil, nil,
server.MockIntegratedValidator{},
) )
if err != nil { if err != nil {
t.Fatalf("failed creating management server: %v", err) t.Fatalf("failed creating management server: %v", err)

View File

@@ -283,7 +283,7 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er
} }
} }
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
} }
@@ -377,6 +377,11 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error { func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
var model T var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db} stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil { if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("failed to parse model schema: %w", err) return fmt.Errorf("failed to parse model schema: %w", err)
@@ -384,6 +389,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
tableName := stmt.Schema.Table tableName := stmt.Schema.Table
dialect := db.Dialector.Name() dialect := db.Dialector.Name()
if db.Migrator().HasIndex(&model, indexName) {
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
return nil
}
var columnClause string var columnClause string
if dialect == "mysql" { if dialect == "mysql" {
var withLength []string var withLength []string

View File

@@ -4,16 +4,21 @@ import (
"context" "context"
"encoding/gob" "encoding/gob"
"net" "net"
"os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -21,7 +26,41 @@ import (
func setupDatabase(t *testing.T) *gorm.DB { func setupDatabase(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) var db *gorm.DB
var err error
var dsn string
var cleanup func()
switch os.Getenv("NETBIRD_STORE_ENGINE") {
case "mysql":
cleanup, dsn, err = testutil.CreateMysqlTestContainer()
if err != nil {
t.Fatalf("Failed to create MySQL test container: %v", err)
}
if dsn == "" {
t.Fatal("MySQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
case "postgres":
cleanup, dsn, err = testutil.CreatePostgresTestContainer()
if err != nil {
t.Fatalf("Failed to create PostgreSQL test container: %v", err)
}
if dsn == "" {
t.Fatalf("PostgreSQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case "sqlite":
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
default:
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
}
if cleanup != nil {
t.Cleanup(cleanup)
}
require.NoError(t, err, "Failed to open database") require.NoError(t, err, "Failed to open database")
return db return db
@@ -34,6 +73,7 @@ func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &route.Route{}) err := db.AutoMigrate(&types.Account{}, &route.Route{})
@@ -97,6 +137,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
@@ -117,12 +158,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
Peers []peer `gorm:"foreignKey:AccountID;references:id"` Peers []peer `gorm:"foreignKey:AccountID;references:id"`
} }
err = db.Save(&account{ a := &account{
Account: types.Account{Id: "123"}, Account: types.Account{Id: "123"},
Peers: []peer{ }
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, err = db.Save(a).Error
).Error require.NoError(t, err, "Failed to insert account")
a.Peers = []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(a).Error
require.NoError(t, err, "Failed to insert blob data") require.NoError(t, err, "Failed to insert blob data")
var blobValue string var blobValue string
@@ -143,12 +190,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.Account{ account := &types.Account{
Id: "1234", Id: "1234",
PeersG: []nbpeer.Peer{ }
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, err = db.Save(account).Error
).Error require.NoError(t, err, "Failed to insert account")
account.PeersG = []nbpeer.Peer{
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(account).Error
require.NoError(t, err, "Failed to insert JSON data") require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
@@ -162,12 +215,13 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.SetupKey{}) err := db.AutoMigrate(&types.SetupKey{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -192,6 +246,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
KeySecret: "EEFDA****", KeySecret: "EEFDA****",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -213,8 +268,9 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -235,8 +291,9 @@ func TestDropIndex(t *testing.T) {
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -249,3 +306,37 @@ func TestDropIndex(t *testing.T) {
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
assert.False(t, exist, "Should not have the index") assert.False(t, exist, "Should not have the index")
} }
func TestCreateIndex(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}
func TestCreateIndexIfExists(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Create index should not fail if index exists")
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}

View File

@@ -120,14 +120,20 @@ type MockAccountManager struct {
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
} }
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
// do nothing if am.UpdateAccountPeersFunc != nil {
am.UpdateAccountPeersFunc(ctx, accountID)
}
} }
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
// do nothing if am.BufferUpdateAccountPeersFunc != nil {
am.BufferUpdateAccountPeersFunc(ctx, accountID)
}
} }
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {

View File

@@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createNSStore(t *testing.T) (store.Store, error) { func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -9,6 +9,7 @@ import (
"slices" "slices"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
@@ -236,11 +237,23 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if peer.Name != update.Name { if peer.Name != update.Name {
var newLabel string var newLabel string
newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name)
newLabel, err = nbdns.GetParsedDomainLabel(update.Name)
if err != nil { if err != nil {
return fmt.Errorf("failed to get free DNS label: %w", err) newLabel = ""
} else {
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name)
if err == nil {
newLabel = ""
}
} }
if newLabel == "" {
newLabel, err = getPeerIPDNSLabel(peer.IP, update.Name)
if err != nil {
return fmt.Errorf("failed to get free DNS label: %w", err)
}
}
peer.Name = update.Name peer.Name = update.Name
peer.DNSLabel = newLabel peer.DNSLabel = newLabel
peerLabelChanged = true peerLabelChanged = true
@@ -472,6 +485,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
var groupsToAdd []string var groupsToAdd []string
var allowExtraDNSLabels bool var allowExtraDNSLabels bool
var accountID string var accountID string
var isEphemeral bool
if addedByUser { if addedByUser {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
@@ -501,7 +515,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
setupKeyName = sk.Name setupKeyName = sk.Name
allowExtraDNSLabels = sk.AllowExtraDNSLabels allowExtraDNSLabels = sk.AllowExtraDNSLabels
accountID = sk.AccountID accountID = sk.AccountID
isEphemeral = sk.Ephemeral
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
} }
@@ -573,11 +587,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
var freeLabel string var freeLabel string
freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname) if isEphemeral || attempt > 1 {
if err != nil { freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname)
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
}
} else {
freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
}
} }
newPeer.DNSLabel = freeLabel newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP newPeer.IP = freeIP
@@ -647,7 +667,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if isUniqueConstraintError(err) { if isUniqueConstraintError(err) {
unlock() unlock()
unlock = nil unlock = nil
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err) log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err)
continue continue
} }
@@ -681,7 +701,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
} }
func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) { func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
ip = ip.To4() ip = ip.To4()
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName) dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
@@ -689,12 +709,6 @@ func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID
return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err) return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
} }
_, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName)
if err != nil {
//nolint:nilerr
return dnsName, nil
}
return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
} }
@@ -1267,18 +1281,39 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
} }
} }
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { type bufferUpdate struct {
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{}) mu sync.Mutex
lock := mu.(*sync.Mutex) next *time.Timer
update atomic.Bool
}
if !lock.TryLock() { func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return return
} }
if b.next != nil {
b.next.Stop()
}
go func() { go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load())) defer b.mu.Unlock()
lock.Unlock()
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
am.UpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
}() }()
} }

View File

@@ -13,6 +13,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -25,6 +26,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@@ -1271,7 +1273,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1351,7 +1353,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1494,7 +1496,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1568,7 +1570,7 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1846,7 +1848,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, true, nil return update, true, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
@@ -1868,7 +1870,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, false, nil return update, false, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)
@@ -2251,3 +2253,131 @@ func Test_AddPeer(t *testing.T) {
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes) assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
assert.Equal(t, uint64(totalPeers), account.Network.Serial) assert.Equal(t, uint64(totalPeers), account.Network.Serial)
} }
func TestBufferUpdateAccountPeers(t *testing.T) {
const (
peersCount = 1000
updateAccountInterval = 50 * time.Millisecond
)
var (
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
uapLastRun, dpLastRun atomic.Int64
totalNewRuns, totalOldRuns int
)
uap := func(ctx context.Context, accountID string) {
updatePeersDeleted.Store(deletedPeers.Load())
updatePeersRuns.Add(1)
uapLastRun.Store(time.Now().UnixMilli())
time.Sleep(100 * time.Millisecond)
}
t.Run("new approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
b := mu.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
uap(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
b.next = time.AfterFunc(updateAccountInterval, func() {
uap(ctx, accountID)
})
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalNewRuns = int(updatePeersRuns.Load())
})
t.Run("old approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
b := mu.(*sync.Mutex)
if !b.TryLock() {
return
}
go func() {
time.Sleep(updateAccountInterval)
b.Unlock()
uap(ctx, accountID)
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalOldRuns = int(updatePeersRuns.Load())
})
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}

View File

@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createRouterStore(t *testing.T) (store.Store, error) { func createRouterStore(t *testing.T) (store.Store, error) {

View File

@@ -852,7 +852,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{}, integratedPeerValidator: MockIntegratedValidator{},
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
} }

View File

@@ -9,6 +9,8 @@ var (
baseWallNano int64 baseWallNano int64
) )
type Time int64
func init() { func init() {
baseWallTime = time.Now() baseWallTime = time.Now()
baseWallNano = baseWallTime.UnixNano() baseWallNano = baseWallTime.UnixNano()
@@ -23,7 +25,11 @@ func init() {
// and using time.Since() for elapsed calculation, this avoids repeated // and using time.Since() for elapsed calculation, this avoids repeated
// time.Now() calls and leverages Go's internal monotonic clock for // time.Now() calls and leverages Go's internal monotonic clock for
// efficient duration measurement. // efficient duration measurement.
func Now() int64 { func Now() Time {
elapsed := time.Since(baseWallTime) elapsed := time.Since(baseWallTime)
return baseWallNano + int64(elapsed) return Time(baseWallNano + int64(elapsed))
}
func Since(t Time) time.Duration {
return time.Duration(Now() - t)
} }

View File

@@ -7,13 +7,6 @@ import (
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
// Validator is an interface that defines the Validate method.
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct { type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator authenticator *auth.TimedHMACValidator

View File

@@ -124,15 +124,14 @@ func (cc *connContainer) close() {
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server. // While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct { type Client struct {
log *log.Entry log *log.Entry
parentCtx context.Context
connectionURL string connectionURL string
authTokenStore *auth.TokenStore authTokenStore *auth.TokenStore
hashedID []byte hashedID messages.PeerID
bufPool *sync.Pool bufPool *sync.Pool
relayConn net.Conn relayConn net.Conn
conns map[string]*connContainer conns map[messages.PeerID]*connContainer
serviceIsRunning bool serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex readLoopMutex sync.Mutex
@@ -142,14 +141,17 @@ type Client struct {
onDisconnectListener func(string) onDisconnectListener func(string)
listenerMutex sync.Mutex listenerMutex sync.Mutex
stateSubscription *PeersStateSubscription
} }
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})
c := &Client{ c := &Client{
log: log.WithFields(log.Fields{"relay": serverURL}), log: relayLog,
parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authTokenStore: authTokenStore, authTokenStore: authTokenStore,
hashedID: hashedID, hashedID: hashedID,
@@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
return &buf return &buf
}, },
}, },
conns: make(map[string]*connContainer), conns: make(map[messages.PeerID]*connContainer),
} }
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
return c return c
} }
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error { func (c *Client) Connect(ctx context.Context) error {
c.log.Infof("connecting to relay server") c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
@@ -178,17 +181,27 @@ func (c *Client) Connect() error {
return nil return nil
} }
if err := c.connect(); err != nil { instanceURL, err := c.connect(ctx)
if err != nil {
return err return err
} }
c.muInstanceURL.Lock()
c.instanceURL = instanceURL
c.muInstanceURL.Unlock()
c.log = c.log.WithField("relay", c.instanceURL.String()) c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
c.log = c.log.WithField("relay", instanceURL.String())
c.log.Infof("relay connection established") c.log.Infof("relay connection established")
c.serviceIsRunning = true c.serviceIsRunning = true
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
return nil return nil
} }
@@ -196,26 +209,50 @@ func (c *Client) Connect() error {
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress // OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise, // to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately. // it will return immediately.
// It block until the server confirm the peer is online.
// todo: what should happen if call with the same peerID with multiple times? // todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
c.mu.Lock() peerID := messages.HashID(dstPeerID)
defer c.mu.Unlock()
c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
_, ok := c.conns[peerID]
if ok {
c.mu.Unlock()
return nil, ErrConnAlreadyExists
}
c.mu.Unlock()
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
c.log.Errorf("peer not available: %s, %s", peerID, err)
return nil, err
}
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established") return nil, fmt.Errorf("relay connection is not established")
} }
hashedID, hashedStringID := messages.HashID(dstPeerID) c.muInstanceURL.Lock()
_, ok := c.conns[hashedStringID] instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)
_, ok = c.conns[peerID]
if ok { if ok {
c.mu.Unlock()
_ = conn.Close()
return nil, ErrConnAlreadyExists return nil, ErrConnAlreadyExists
} }
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
c.log.Infof("open connection to peer: %s", hashedStringID) c.mu.Unlock()
msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil return conn, nil
} }
@@ -254,76 +291,70 @@ func (c *Client) Close() error {
return c.close(true) return c.close(true)
} }
func (c *Client) connect() error { func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
return err return nil, err
} }
c.relayConn = conn c.relayConn = conn
if err = c.handShake(); err != nil { instanceURL, err := c.handShake(ctx)
if err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
} }
return err return nil, err
} }
return nil return instanceURL, nil
} }
func (c *Client) handShake() error { func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil { if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err) c.log.Errorf("failed to marshal auth message: %s", err)
return err return nil, err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
c.log.Errorf("failed to send auth message: %s", err) c.log.Errorf("failed to send auth message: %s", err)
return err return nil, err
} }
buf := make([]byte, messages.MaxHandshakeRespSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(ctx, buf)
if err != nil { if err != nil {
c.log.Errorf("failed to read auth response: %s", err) c.log.Errorf("failed to read auth response: %s", err)
return err return nil, err
} }
_, err = messages.ValidateVersion(buf[:n]) _, err = messages.ValidateVersion(buf[:n])
if err != nil { if err != nil {
return fmt.Errorf("validate version: %w", err) return nil, fmt.Errorf("validate version: %w", err)
} }
msgType, err := messages.DetermineServerMessageType(buf[:n]) msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil { if err != nil {
c.log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
return err return nil, err
} }
if msgType != messages.MsgTypeAuthResponse { if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType) c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return nil, fmt.Errorf("unexpected message type")
} }
addr, err := messages.UnmarshalAuthResponse(buf[:n]) addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil { if err != nil {
return err return nil, err
} }
c.muInstanceURL.Lock() return &RelayAddr{addr: addr}, nil
c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
} }
func (c *Client) readLoop(relayConn net.Conn) { func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var ( var (
errExit error errExit error
n int n int
@@ -366,10 +397,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
hc.Stop() hc.Stop()
c.muInstanceURL.Lock() c.stateSubscription.Cleanup()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(false) _ = c.close(false)
c.notifyDisconnected() c.notifyDisconnected()
@@ -382,6 +410,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypePeersOnline:
c.handlePeersOnlineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypePeersWentOffline:
c.handlePeersWentOfflineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypeClose: case messages.MsgTypeClose:
c.log.Debugf("relay connection close by server") c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
@@ -413,18 +449,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
stringID := messages.HashIDToString(peerID)
c.mu.Lock() c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock() c.mu.Unlock()
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return false return false
} }
container, ok := c.conns[stringID] container, ok := c.conns[*peerID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
c.log.Errorf("peer not found: %s", stringID) c.log.Errorf("peer not found: %s", peerID.String())
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return true return true
} }
@@ -437,9 +471,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
c.mu.Lock() c.mu.Lock()
conn, ok := c.conns[id] conn, ok := c.conns[dstID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
@@ -464,7 +498,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
return len(payload), err return len(payload), err
} }
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for { for {
select { select {
case _, ok := <-hc.OnTimeout: case _, ok := <-hc.OnTimeout:
@@ -478,7 +512,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
c.log.Warnf("failed to close connection: %s", err) c.log.Warnf("failed to close connection: %s", err)
} }
return return
case <-c.parentCtx.Done(): case <-ctx.Done():
err := c.close(true) err := c.close(true)
if err != nil { if err != nil {
c.log.Errorf("failed to teardown connection: %s", err) c.log.Errorf("failed to teardown connection: %s", err)
@@ -492,10 +526,31 @@ func (c *Client) closeAllConns() {
for _, container := range c.conns { for _, container := range c.conns {
container.close() container.close()
} }
c.conns = make(map[string]*connContainer) c.conns = make(map[messages.PeerID]*connContainer)
} }
func (c *Client) closeConn(connReference *Conn, id string) error { func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
c.mu.Lock()
defer c.mu.Unlock()
for _, peerID := range peerIDs {
container, ok := c.conns[peerID]
if !ok {
c.log.Warnf("can not close connection, peer not found: %s", peerID)
continue
}
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
container.close()
delete(c.conns, peerID)
}
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
}
}
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -507,6 +562,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference { if container.conn != connReference {
return fmt.Errorf("conn reference mismatch") return fmt.Errorf("conn reference mismatch")
} }
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
}
c.log.Infof("free up connection to peer: %s", id) c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id) delete(c.conns, id)
container.close() container.close()
@@ -525,8 +585,12 @@ func (c *Client) close(gracefullyExit bool) error {
c.log.Warn("relay connection was already marked as not running") c.log.Warn("relay connection was already marked as not running")
return nil return nil
} }
c.serviceIsRunning = false c.serviceIsRunning = false
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.log.Infof("closing all peer connections") c.log.Infof("closing all peer connections")
c.closeAllConns() c.closeAllConns()
if gracefullyExit { if gracefullyExit {
@@ -559,8 +623,8 @@ func (c *Client) writeCloseMsg() {
} }
} }
func (c *Client) readWithTimeout(buf []byte) (int, error) { func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
defer cancel() defer cancel()
readDone := make(chan struct{}) readDone := make(chan struct{})
@@ -581,3 +645,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) {
return n, err return n, err
} }
} }
func (c *Client) handlePeersOnlineMsg(buf []byte) {
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
return
}
c.stateSubscription.OnPeersOnline(peersID)
}
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
peersID, err := messages.UnMarshalPeersWentOffline(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
return
}
c.stateSubscription.OnPeersWentOffline(peersID)
}

View File

@@ -18,14 +18,19 @@ import (
) )
var ( var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{} hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234" serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234" serverURL = "rel://127.0.0.1:1234"
serverCfg = server.Config{
Meter: otel.Meter(""),
ExposedAddress: serverURL,
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
_ = util.InitLog("error", "console") _ = util.InitLog("debug", "console")
code := m.Run() code := m.Run()
os.Exit(code) os.Exit(code)
} }
@@ -33,7 +38,7 @@ func TestMain(m *testing.M) {
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -58,37 +63,37 @@ func TestClient(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
t.Log("alice connecting to server") t.Log("alice connecting to server")
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientAlice.Close() defer clientAlice.Close()
t.Log("placeholder connecting to server") t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect() err = clientPlaceHolder.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientPlaceHolder.Close() defer clientPlaceHolder.Close()
t.Log("Bob connecting to server") t.Log("Bob connecting to server")
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientBob.Close() defer clientBob.Close()
t.Log("Alice open connection to Bob") t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn("bob") connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
t.Log("Bob open connection to Alice") t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn("alice") connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -115,7 +120,7 @@ func TestClient(t *testing.T) {
func TestRegistration(t *testing.T) { func TestRegistration(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
_ = srv.Shutdown(ctx) _ = srv.Shutdown(ctx)
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) {
_ = fakeTCPListener.Close() _ = fakeTCPListener.Close()
}(fakeTCPListener) }(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err == nil { if err == nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
@@ -189,7 +194,7 @@ func TestEcho(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -213,8 +218,8 @@ func TestEcho(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -225,8 +230,8 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -237,12 +242,12 @@ func TestEcho(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err == nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("expected error when binding to unavailable peer, got nil")
} }
log.Infof("closing client") log.Infof("closing client")
@@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") chBob, err := clientBob.OpenConn(ctx, "alice")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
chAlice, err := clientAlice.OpenConn("bob") chAlice, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
testString := "hello alice, I am bob" testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err == nil {
t.Errorf("expected error when writing to channel, got nil")
}
chBob, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_, err = chBob.Write([]byte(testString)) _, err = chBob.Write([]byte(testString))
if err != nil { if err != nil {
t.Errorf("failed to write to channel: %s", err) t.Errorf("failed to write to channel: %s", err)
@@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) {
t.Errorf("unexpected reading from closed connection") t.Errorf("unexpected reading from closed connection")
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -544,8 +571,8 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() err = relayClient.Connect(ctx)
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@@ -567,7 +594,7 @@ func TestCloseByServer(t *testing.T) {
log.Fatalf("timeout waiting for client to disconnect") log.Fatalf("timeout waiting for client to disconnect")
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -577,7 +604,7 @@ func TestCloseByClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -596,8 +623,8 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() err = relayClient.Connect(ctx)
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@@ -607,7 +634,7 @@ func TestCloseByClient(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -623,7 +650,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -647,8 +674,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -659,8 +686,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -671,12 +698,12 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -3,13 +3,14 @@ package client
import ( import (
"net" "net"
"time" "time"
"github.com/netbirdio/netbird/relay/messages"
) )
// Conn represent a connection to a relayed remote peer. // Conn represent a connection to a relayed remote peer.
type Conn struct { type Conn struct {
client *Client client *Client
dstID []byte dstID messages.PeerID
dstStringID string
messageChan chan Msg messageChan chan Msg
instanceURL *RelayAddr instanceURL *RelayAddr
} }
@@ -17,14 +18,12 @@ type Conn struct {
// NewConn creates a new connection to a relayed remote peer. // NewConn creates a new connection to a relayed remote peer.
// client: the client instance, it used to send messages to the destination peer // client: the client instance, it used to send messages to the destination peer
// dstID: the destination peer ID // dstID: the destination peer ID
// dstStringID: the destination peer ID in string format
// messageChan: the channel where the messages will be received // messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer // instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn { func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
dstID: dstID, dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan, messageChan: messageChan,
instanceURL: instanceURL, instanceURL: instanceURL,
} }
@@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c, c.dstStringID, c.dstID, p) return c.client.writeTo(c, c.dstID, p)
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
@@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.client.closeConn(c, c.dstStringID) return c.client.closeConn(c, c.dstID)
} }
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {

View File

@@ -80,7 +80,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := rc.Connect(); err != nil { if err := rc.Connect(parentCtx); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err) log.Errorf("failed to reconnect to relay server: %s", err)
return false return false
} }

View File

@@ -42,7 +42,7 @@ type OnServerCloseListener func()
// ManagerService is the interface for the relay manager. // ManagerService is the interface for the relay manager.
type ManagerService interface { type ManagerService interface {
Serve() error Serve() error
OpenConn(serverAddress, peerKey string) (net.Conn, error) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error) RelayInstanceAddress() (string, error)
ServerURLs() []string ServerURLs() []string
@@ -65,7 +65,7 @@ type Manager struct {
relayClient *Client relayClient *Client
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
relayClientMu sync.Mutex relayClientMu sync.RWMutex
reconnectGuard *Guard reconnectGuard *Guard
relayClients map[string]*RelayTrack relayClients map[string]*RelayTrack
@@ -123,9 +123,9 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return nil, ErrRelayClientNotConnected return nil, ErrRelayClientNotConnected
@@ -141,10 +141,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
) )
if !foreign { if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey) log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(peerKey) netConn, err = m.relayClient.OpenConn(ctx, peerKey)
} else { } else {
log.Debugf("open peer connection via foreign server: %s", serverAddress) log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(serverAddress, peerKey) netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -155,8 +155,8 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
// Ready returns true if the home Relay client is connected to the relay server. // Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool { func (m *Manager) Ready() bool {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return false return false
@@ -174,8 +174,8 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed. // closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return ErrRelayClientNotConnected return ErrRelayClientNotConnected
@@ -199,8 +199,8 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication. // lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) { func (m *Manager) RelayInstanceAddress() (string, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return "", ErrRelayClientNotConnected return "", ErrRelayClientNotConnected
@@ -229,7 +229,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token) return m.tokenStore.UpdateToken(token)
} }
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server // check if already has a connection to the desired relay server
m.relayClientsMutex.RLock() m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress] rt, ok := m.relayClients[serverAddress]
@@ -240,7 +240,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
m.relayClientsMutex.RUnlock() m.relayClientsMutex.RUnlock()
@@ -255,7 +255,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
// create a new relay client and store it in the relayClients map // create a new relay client and store it in the relayClients map
@@ -264,8 +264,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
m.relayClients[serverAddress] = rt m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock() m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID) relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect() err := relayClient.Connect(m.ctx)
if err != nil { if err != nil {
rt.err = err rt.err = err
rt.Unlock() rt.Unlock()
@@ -279,7 +279,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
rt.relayClient = relayClient rt.relayClient = relayClient
rt.Unlock() rt.Unlock()
conn, err := relayClient.OpenConn(peerKey) conn, err := relayClient.OpenConn(ctx, peerKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -300,7 +300,9 @@ func (m *Manager) onServerConnected() {
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock() m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL { if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) go func(client *Client) {
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
}(m.relayClient)
} }
m.relayClientMu.Unlock() m.relayClientMu.Unlock()

View File

@@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
) )
@@ -22,16 +23,22 @@ func TestEmptyURL(t *testing.T) {
func TestForeignConn(t *testing.T) { func TestForeignConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ lstCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
srv1, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: lstCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) err := srv1.Listen(lstCfg1)
if err != nil { if err != nil {
errChan <- err errChan <- err
} }
@@ -51,7 +58,12 @@ func TestForeignConn(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: srvCfg2.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -74,32 +86,26 @@ func TestForeignConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
err = clientAlice.Serve() if err := clientAlice.Serve(); err != nil {
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
idBob := "bob" clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
log.Debugf("connect by bob") if err := clientBob.Serve(); err != nil {
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
bobsSrvAddr, err := clientBob.RelayInstanceAddress() bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil { if err != nil {
t.Fatalf("failed to get relay address: %s", err) t.Fatalf("failed to get relay address: %s", err)
} }
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -137,7 +143,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -163,7 +169,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -186,16 +192,20 @@ func TestForeginConnClose(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := mgrBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
err = mgr.Serve() err = mgr.Serve()
if err != nil { if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -212,23 +222,21 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
t.Log("binding server 1.") t.Log("binding server 1.")
err := srv1.Listen(srvCfg1) if err := srv1.Listen(srvCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
t.Logf("closing server 1.") t.Logf("closing server 1.")
err := srv1.Shutdown(ctx) if err := srv1.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
t.Logf("server 1. closed") t.Logf("server 1. closed")
@@ -241,7 +249,7 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -277,15 +285,8 @@ func TestForeginAutoClose(t *testing.T) {
} }
t.Log("open connection to another peer") t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
if err != nil { t.Fatalf("should have failed to open connection to another peer")
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
} }
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
@@ -305,7 +306,7 @@ func TestAutoReconnect(t *testing.T) {
srvCfg := server.ListenerConfig{ srvCfg := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -330,6 +331,13 @@ func TestAutoReconnect(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve() err = clientAlice.Serve()
if err != nil { if err != nil {
@@ -339,7 +347,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to get relay address: %s", err) t.Errorf("failed to get relay address: %s", err)
} }
conn, err := clientAlice.OpenConn(ra, "bob") conn, err := clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -357,7 +365,7 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second) time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob") _, err = clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to open channel: %s", err) t.Errorf("failed to open channel: %s", err)
} }
@@ -366,24 +374,27 @@ func TestAutoReconnect(t *testing.T) {
func TestNotifierDoubleAdd(t *testing.T) { func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ listenerCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: listenerCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) if err := srv.Listen(listenerCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
err := srv1.Shutdown(ctx) if err := srv.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
}() }()
@@ -392,17 +403,21 @@ func TestNotifierDoubleAdd(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve() clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
if err != nil { if err = clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob") clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
if err = clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -0,0 +1,191 @@
package client
import (
"context"
"errors"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
const (
OpenConnectionTimeout = 30 * time.Second
)
type relayedConnWriter interface {
Write(p []byte) (n int, err error)
}
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
// over a relay connection. It allows tracking peers' availability and handling offline
// events via a callback. We get online notification from the server only once.
type PeersStateSubscription struct {
log *log.Entry
relayConn relayedConnWriter
offlineCallback func(peerIDs []messages.PeerID)
listenForOfflinePeers map[messages.PeerID]struct{}
waitingPeers map[messages.PeerID]chan struct{}
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
}
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
return &PeersStateSubscription{
log: log,
relayConn: relayConn,
offlineCallback: offlineCallback,
listenForOfflinePeers: make(map[messages.PeerID]struct{}),
waitingPeers: make(map[messages.PeerID]chan struct{}),
}
}
// OnPeersOnline should be called when a notification is received that certain peers have come online.
// It checks if any of the peers are being waited on and signals their availability.
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, peerID := range peersID {
waitCh, ok := s.waitingPeers[peerID]
if !ok {
// If meanwhile the peer was unsubscribed, we don't need to signal it
continue
}
waitCh <- struct{}{}
delete(s.waitingPeers, peerID)
close(waitCh)
}
}
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
s.mu.Lock()
relevantPeers := make([]messages.PeerID, 0, len(peersID))
for _, peerID := range peersID {
if _, ok := s.listenForOfflinePeers[peerID]; ok {
relevantPeers = append(relevantPeers, peerID)
}
}
s.mu.Unlock()
if len(relevantPeers) > 0 {
s.offlineCallback(relevantPeers)
}
}
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
// Check if already waiting for this peer
s.mu.Lock()
if _, exists := s.waitingPeers[peerID]; exists {
s.mu.Unlock()
return errors.New("already waiting for peer to come online")
}
// Create a channel to wait for the peer to come online
waitCh := make(chan struct{}, 1)
s.waitingPeers[peerID] = waitCh
s.listenForOfflinePeers[peerID] = struct{}{}
s.mu.Unlock()
if err := s.subscribeStateChange(peerID); err != nil {
s.log.Errorf("failed to subscribe to peer state: %s", err)
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return err
}
// Wait for peer to come online or context to be cancelled
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
defer cancel()
select {
case _, ok := <-waitCh:
if !ok {
return fmt.Errorf("wait for peer to come online has been cancelled")
}
s.log.Debugf("peer %s is now online", peerID)
return nil
case <-timeoutCtx.Done():
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
}
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return timeoutCtx.Err()
}
}
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
msgErr := s.unsubscribeStateChange(peerIDs)
s.mu.Lock()
for _, peerID := range peerIDs {
if wch, ok := s.waitingPeers[peerID]; ok {
close(wch)
delete(s.waitingPeers, peerID)
}
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return msgErr
}
func (s *PeersStateSubscription) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
for _, waitCh := range s.waitingPeers {
close(waitCh)
}
s.waitingPeers = make(map[messages.PeerID]chan struct{})
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
}
func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
if err != nil {
return err
}
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
return err
}
}
return nil
}
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
if err != nil {
return err
}
var connWriteErr error
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
connWriteErr = err
}
}
return connWriteErr
}

View File

@@ -0,0 +1,99 @@
package client
import (
"bytes"
"context"
"testing"
"time"
"github.com/netbirdio/netbird/relay/messages"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockRelayedConn struct {
}
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
return len(p), nil
}
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
peerID := messages.HashID("peer1")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{}) // discard log output
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Launch wait in background
go func() {
time.Sleep(100 * time.Millisecond)
sub.OnPeersOnline([]messages.PeerID{peerID})
}()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.NoError(t, err)
}
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
peerID := messages.HashID("peer2")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
}
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
peerID := messages.HashID("peer3")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx := context.Background()
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
}()
time.Sleep(100 * time.Millisecond)
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "already waiting")
}
func TestUnsubscribeStateChange(t *testing.T) {
peerID := messages.HashID("peer4")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
doneChan := make(chan struct{})
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
close(doneChan)
}()
time.Sleep(100 * time.Millisecond)
err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
assert.NoError(t, err)
select {
case <-doneChan:
case <-time.After(200 * time.Millisecond):
// Expected timeout, meaning the subscription was successfully unsubscribed
t.Errorf("timeout")
}
}

View File

@@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url) log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect() err := relayClient.Connect(ctx)
resultChan <- connResult{ resultChan <- connResult{
RelayClient: relayClient, RelayClient: relayClient,
Url: url, Url: url,

View File

@@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error {
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) cfg := server.Config{
Meter: metricsServer.Meter,
ExposedAddress: cobraConfig.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
srv, err := server.NewServer(cfg)
if err != nil { if err != nil {
log.Debugf("failed to create relay server: %v", err) log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err)

View File

@@ -8,24 +8,24 @@ import (
const ( const (
prefixLength = 4 prefixLength = 4
IDSize = prefixLength + sha256.Size peerIDSize = prefixLength + sha256.Size
) )
var ( var (
prefix = []byte("sha-") // 4 bytes prefix = []byte("sha-") // 4 bytes
) )
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string type PeerID [peerIDSize]byte
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID)) func (p PeerID) String() string {
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
var prefixedHash []byte
prefixedHash = append(prefixedHash, prefix...)
prefixedHash = append(prefixedHash, idHash[:]...)
return prefixedHash, idHashString
} }
// HashIDToString converts a hash to a human-readable string // HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
func HashIDToString(idHash []byte) string { func HashID(peerID string) PeerID {
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) idHash := sha256.Sum256([]byte(peerID))
var prefixedHash [peerIDSize]byte
copy(prefixedHash[:prefixLength], prefix)
copy(prefixedHash[prefixLength:], idHash[:])
return prefixedHash
} }

View File

@@ -1,13 +0,0 @@
package messages
import (
"testing"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("alice")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
}

View File

@@ -9,19 +9,26 @@ import (
const ( const (
MaxHandshakeSize = 212 MaxHandshakeSize = 212
MaxHandshakeRespSize = 8192 MaxHandshakeRespSize = 8192
MaxMessageSize = 8820
CurrentProtocolVersion = 1 CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0 MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead. // Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1 MsgTypeHello = 1
// Deprecated: Use MsgTypeAuthResponse instead. // Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2 MsgTypeHelloResponse = 2
MsgTypeTransport MsgType = 3 MsgTypeTransport = 3
MsgTypeClose MsgType = 4 MsgTypeClose = 4
MsgTypeHealthCheck MsgType = 5 MsgTypeHealthCheck = 5
MsgTypeAuth = 6 MsgTypeAuth = 6
MsgTypeAuthResponse = 7 MsgTypeAuthResponse = 7
// Peers state messages
MsgTypeSubscribePeerState = 8
MsgTypeUnsubscribePeerState = 9
MsgTypePeersOnline = 10
MsgTypePeersWentOffline = 11
// base size of the message // base size of the message
sizeOfVersionByte = 1 sizeOfVersionByte = 1
@@ -30,17 +37,17 @@ const (
// auth message // auth message
sizeOfMagicByte = 4 sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + IDSize headerSizeAuth = sizeOfMagicByte + peerIDSize
offsetMagicByte = sizeOfProtoHeader offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
// hello message // hello message
headerSizeHello = sizeOfMagicByte + IDSize headerSizeHello = sizeOfMagicByte + peerIDSize
headerSizeHelloResp = 0 headerSizeHelloResp = 0
// transport // transport
headerSizeTransport = IDSize headerSizeTransport = peerIDSize
offsetTransportID = sizeOfProtoHeader offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
) )
@@ -72,6 +79,14 @@ func (m MsgType) String() string {
return "close" return "close"
case MsgTypeHealthCheck: case MsgTypeHealthCheck:
return "health check" return "health check"
case MsgTypeSubscribePeerState:
return "subscribe peer state"
case MsgTypeUnsubscribePeerState:
return "unsubscribe peer state"
case MsgTypePeersOnline:
return "peers online"
case MsgTypePeersWentOffline:
return "peers went offline"
default: default:
return "unknown" return "unknown"
} }
@@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
MsgTypeAuth, MsgTypeAuth,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypeSubscribePeerState,
MsgTypeUnsubscribePeerState:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
MsgTypeAuthResponse, MsgTypeAuthResponse,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypePeersOnline,
MsgTypePeersWentOffline:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
@@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID[:]...)
msg = append(msg, additions...) msg = append(msg, additions...)
return msg, nil return msg, nil
@@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthMsg instead. // Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < sizeOfProtoHeader+headerSizeHello { if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
@@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
return &peerID, msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead. // Deprecated: Use MarshalAuthResponse instead.
@@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize { if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("too large auth payload")
} }
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) msg := make([]byte, headerTotalSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth) msg[1] = byte(MsgTypeAuth)
copy(msg[sizeOfProtoHeader:], magicHeader) copy(msg[sizeOfProtoHeader:], magicHeader)
copy(msg[offsetAuthPeerID:], peerID[:])
msg = append(msg, peerID...) copy(msg[headerTotalSizeAuth:], authPayload)
msg = append(msg, authPayload...)
return msg, nil return msg, nil
} }
// UnmarshalAuthMsg extracts peerID and the auth payload from the message // UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < headerTotalSizeAuth { if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
// Validate the magic header
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
return &peerID, msg[headerTotalSizeAuth:], nil
} }
// MarshalAuthResponse creates a response message to the auth. // MarshalAuthResponse creates a response message to the auth.
@@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte {
// MarshalTransportMsg creates a transport message. // MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID. // destination peer hashed ID.
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize { // todo validate size
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) msg := make([]byte, headerTotalSizeTransport+len(payload))
}
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport) msg[1] = byte(MsgTypeTransport)
copy(msg[sizeOfProtoHeader:], peerID) copy(msg[sizeOfProtoHeader:], peerID[:])
msg = append(msg, payload...) copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
return msg, nil return msg, nil
} }
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message. // UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil const offsetEnd = offsetTransportID + peerIDSize
var peerID PeerID
copy(peerID[:], buf[offsetTransportID:offsetEnd])
return &peerID, buf[headerTotalSizeTransport:], nil
} }
// UnmarshalTransportID extracts the peerID from the transport message. // UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) { func UnmarshalTransportID(buf []byte) (*PeerID, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], nil
const offsetEnd = offsetTransportID + peerIDSize
var id PeerID
copy(id[:], buf[offsetTransportID:offsetEnd])
return &id, nil
} }
// UpdateTransportMsg updates the peerID in the transport message. // UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice. // need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error { func UpdateTransportMsg(msg []byte, peerID PeerID) error {
if len(msg) < offsetTransportID+len(peerID) { if len(msg) < offsetTransportID+peerIDSize {
return ErrInvalidMessageLength return ErrInvalidMessageLength
} }
copy(msg[offsetTransportID:], peerID) copy(msg[offsetTransportID:], peerID[:])
return nil return nil
} }

View File

@@ -5,7 +5,7 @@ import (
) )
func TestMarshalHelloMsg(t *testing.T) { func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalHelloMsg(peerID, nil) msg, err := MarshalHelloMsg(peerID, nil)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
func TestMarshalAuthMsg(t *testing.T) { func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalAuthMsg(peerID, []byte{}) msg, err := MarshalAuthMsg(peerID, []byte{})
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
@@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) {
} }
func TestMarshalTransportMsg(t *testing.T) { func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload") payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload) msg, err := MarshalTransportMsg(peerID, payload)
if err != nil { if err != nil {
@@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("failed to unmarshal transport id: %v", err) t.Fatalf("failed to unmarshal transport id: %v", err)
} }
if string(uPeerID) != string(peerID) { if uPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, uPeerID) t.Errorf("expected %s, got %s", peerID, uPeerID)
} }
@@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(id) != string(peerID) { if id.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, id) t.Errorf("expected: '%s', got: '%s'", peerID, id)
} }
if string(respPayload) != string(payload) { if string(respPayload) != string(payload) {

View File

@@ -0,0 +1,92 @@
package messages
import (
"fmt"
)
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
}
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
}
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
}
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
}
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
if len(ids) == 0 {
return nil, fmt.Errorf("no list of peer ids provided")
}
const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
var messages [][]byte
for i := 0; i < len(ids); i += maxPeersPerMessage {
end := i + maxPeersPerMessage
if end > len(ids) {
end = len(ids)
}
chunk := ids[i:end]
totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
buf := make([]byte, totalSize)
buf[0] = byte(CurrentProtocolVersion)
buf[1] = msgType
offset := sizeOfProtoHeader
for _, id := range chunk {
copy(buf[offset:], id[:])
offset += peerIDSize
}
messages = append(messages, buf)
}
return messages, nil
}
// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
if len(buf) < sizeOfProtoHeader {
return nil, fmt.Errorf("invalid message format")
}
if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
}
numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize
ids := make([]PeerID, numIDs)
offset := sizeOfProtoHeader
for i := 0; i < numIDs; i++ {
copy(ids[i][:], buf[offset:offset+peerIDSize])
offset += peerIDSize
}
return ids, nil
}

View File

@@ -0,0 +1,144 @@
package messages
import (
"bytes"
"testing"
)
const (
testPeerCount = 10
)
// Helper function to generate test PeerIDs
func generateTestPeerIDs(n int) []PeerID {
ids := make([]PeerID, n)
for i := 0; i < n; i++ {
for j := 0; j < peerIDSize; j++ {
ids[i][j] = byte(i + j)
}
}
return ids
}
// Helper function to compare slices of PeerID
func peerIDEqual(a, b []PeerID) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !bytes.Equal(a[i][:], b[i][:]) {
return false
}
}
return true
}
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalSubPeerStateMsg(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalSubPeerStateMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
_, err := MarshalSubPeerStateMsg([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
// Too short
_, err := UnmarshalSubPeerStateMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
// Misaligned length
buf := make([]byte, sizeOfProtoHeader+1)
_, err = UnmarshalSubPeerStateMsg(buf)
if err == nil {
t.Errorf("expected error for misaligned input")
}
}
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersOnline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
_, err := MarshalPeersOnline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
_, err := UnmarshalPeersOnlineMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
}
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersWentOffline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
_, err := MarshalPeersWentOffline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address" "github.com/netbirdio/netbird/relay/messages/address"
@@ -14,6 +13,12 @@ import (
authmsg "github.com/netbirdio/netbird/relay/messages/auth" authmsg "github.com/netbirdio/netbird/relay/messages/auth"
) )
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
// preparedMsg contains the marshalled success response messages // preparedMsg contains the marshalled success response messages
type preparedMsg struct { type preparedMsg struct {
responseHelloMsg []byte responseHelloMsg []byte
@@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
type handshake struct { type handshake struct {
conn net.Conn conn net.Conn
validator auth.Validator validator Validator
preparedMsg *preparedMsg preparedMsg *preparedMsg
handshakeMethodAuth bool handshakeMethodAuth bool
peerID string peerID *messages.PeerID
} }
func (h *handshake) handshakeReceive() ([]byte, error) { func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf) n, err := h.conn.Read(buf)
if err != nil { if err != nil {
@@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
} }
var ( var peerID *messages.PeerID
bytePeerID []byte
peerID string
)
switch msgType { switch msgType {
//nolint:staticcheck //nolint:staticcheck
case messages.MsgTypeHello: case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf) peerID, err = h.handleHelloMsg(buf)
case messages.MsgTypeAuth: case messages.MsgTypeAuth:
h.handshakeMethodAuth = true h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf) peerID, err = h.handleAuthMsg(buf)
default: default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
} }
@@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, err return nil, err
} }
h.peerID = peerID h.peerID = peerID
return bytePeerID, nil return peerID, nil
} }
func (h *handshake) handshakeResponse() error { func (h *handshake) handshakeResponse() error {
@@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error {
return nil return nil
} }
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) {
//nolint:staticcheck //nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) peerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
authMsg, err := authmsg.UnmarshalMsg(authData) authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal auth message: %w", err) return nil, fmt.Errorf("unmarshal auth message: %w", err)
} }
//nolint:staticcheck //nolint:staticcheck
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return peerID, nil
} }
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
if err := h.validator.Validate(authPayload); err != nil { if err := h.validator.Validate(authPayload); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return rawPeerID, nil
} }

View File

@@ -12,43 +12,50 @@ import (
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
) )
const ( const (
bufferSize = 8820 bufferSize = messages.MaxMessageSize
errCloseConn = "failed to close connection to peer: %s" errCloseConn = "failed to close connection to peer: %s"
) )
// Peer represents a peer connection // Peer represents a peer connection
type Peer struct { type Peer struct {
metrics *metrics.Metrics metrics *metrics.Metrics
log *log.Entry log *log.Entry
idS string id messages.PeerID
idB []byte conn net.Conn
conn net.Conn connMu sync.RWMutex
connMu sync.RWMutex store *store.Store
store *Store notifier *store.PeerNotifier
peersListener *store.Listener
} }
// NewPeer creates a new Peer instance and prepare custom logging // NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer { func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
stringID := messages.HashIDToString(id) p := &Peer{
return &Peer{ metrics: metrics,
metrics: metrics, log: log.WithField("peer_id", id.String()),
log: log.WithField("peer_id", stringID), id: id,
idS: stringID, conn: conn,
idB: id, store: store,
conn: conn, notifier: notifier,
store: store,
} }
return p
} }
// Work reads data from the connection // Work reads data from the connection
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. // the message accordingly.
func (p *Peer) Work() { func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() { defer func() {
p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@@ -94,6 +101,10 @@ func (p *Peer) Work() {
} }
} }
func (p *Peer) ID() messages.PeerID {
return p.id
}
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) { func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
switch msgType { switch msgType {
case messages.MsgTypeHealthCheck: case messages.MsgTypeHealthCheck:
@@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
log.Errorf(errCloseConn, err) log.Errorf(errCloseConn, err)
} }
case messages.MsgTypeSubscribePeerState:
p.handleSubscribePeerState(msg)
case messages.MsgTypeUnsubscribePeerState:
p.handleUnsubscribePeerState(msg)
default: default:
p.log.Warnf("received unexpected message type: %s", msgType) p.log.Warnf("received unexpected message type: %s", msgType)
} }
@@ -145,7 +160,7 @@ func (p *Peer) Close() {
// String returns the peer ID // String returns the peer ID
func (p *Peer) String() string { func (p *Peer) String() string {
return p.idS return p.id.String()
} }
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
@@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return return
} }
stringPeerID := messages.HashIDToString(peerID) item, ok := p.store.Peer(*peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok { if !ok {
p.log.Debugf("peer not found: %s", stringPeerID) p.log.Debugf("peer not found: %s", peerID)
return return
} }
dp := item.(*Peer)
err = messages.UpdateTransportMsg(msg, p.idB) err = messages.UpdateTransportMsg(msg, p.id)
if err != nil { if err != nil {
p.log.Errorf("failed to update transport message: %s", err) p.log.Errorf("failed to update transport message: %s", err)
return return
@@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) {
} }
p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
} }
func (p *Peer) handleSubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
if len(onlinePeers) == 0 {
return
}
p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers)
}
func (p *Peer) handleUnsubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.peersListener.RemoveInterestedPeer(peerIDs)
}
func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
msgs, err := messages.MarshalPeersOnline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}

Some files were not shown because too many files have changed in this diff Show More