mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 07:36:39 +00:00
Compare commits
15 Commits
v0.50.2
...
nb-interfa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
417fa6e833 | ||
|
|
a7af15c4fc | ||
|
|
d6ed9c037e | ||
|
|
40fdeda838 | ||
|
|
f6e9d755e4 | ||
|
|
08fd460867 | ||
|
|
4f74509d55 | ||
|
|
58185ced16 | ||
|
|
e67f44f47c | ||
|
|
b524f486e2 | ||
|
|
0dab03252c | ||
|
|
e49bcc343d | ||
|
|
3e6eede152 | ||
|
|
a76c8eafb4 | ||
|
|
2b9f331980 |
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -16,6 +16,6 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: git-town/action@v1
|
- uses: git-town/action@v1.2.1
|
||||||
with:
|
with:
|
||||||
skip-single-stacks: true
|
skip-single-stacks: true
|
||||||
16
.github/workflows/release.yml
vendored
16
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.20"
|
SIGN_PIPE_VER: "v0.0.21"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -231,3 +231,17 @@ jobs:
|
|||||||
ref: ${{ env.SIGN_PIPE_VER }}
|
ref: ${{ env.SIGN_PIPE_VER }}
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||||
|
|
||||||
|
post_on_forum:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
continue-on-error: true
|
||||||
|
needs: [trigger_signer]
|
||||||
|
steps:
|
||||||
|
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
|
||||||
|
with:
|
||||||
|
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||||
|
discourse-base-url: https://forum.netbird.io
|
||||||
|
discourse-author-username: NetBird
|
||||||
|
discourse-category: 17
|
||||||
|
discourse-tags:
|
||||||
|
releases
|
||||||
|
|||||||
@@ -307,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
|||||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||||
} else {
|
} else {
|
||||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return statusOutputString
|
return statusOutputString
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ var (
|
|||||||
statusFilter string
|
statusFilter string
|
||||||
ipsFilterMap map[string]struct{}
|
ipsFilterMap map[string]struct{}
|
||||||
prefixNamesFilterMap map[string]struct{}
|
prefixNamesFilterMap map[string]struct{}
|
||||||
|
connectionTypeFilter string
|
||||||
)
|
)
|
||||||
|
|
||||||
var statusCmd = &cobra.Command{
|
var statusCmd = &cobra.Command{
|
||||||
@@ -45,6 +46,7 @@ func init() {
|
|||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
|
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -89,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
@@ -156,6 +158,15 @@ func parseFilters() error {
|
|||||||
enableDetailFlagWhenFilterFlag()
|
enableDetailFlagWhenFilterFlag()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(connectionTypeFilter) {
|
||||||
|
case "", "p2p", "relayed":
|
||||||
|
if strings.ToLower(connectionTypeFilter) != "" {
|
||||||
|
enableDetailFlagWhenFilterFlag()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,14 +34,14 @@ func NewActivityRecorder() *ActivityRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetLastActivities returns a snapshot of peer last activity
|
// GetLastActivities returns a snapshot of peer last activity
|
||||||
func (r *ActivityRecorder) GetLastActivities() map[string]time.Time {
|
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
defer r.mu.RUnlock()
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
activities := make(map[string]time.Time, len(r.peers))
|
activities := make(map[string]monotime.Time, len(r.peers))
|
||||||
for key, record := range r.peers {
|
for key, record := range r.peers {
|
||||||
unixNano := record.LastActivity.Load()
|
monoTime := record.LastActivity.Load()
|
||||||
activities[key] = time.Unix(0, unixNano)
|
activities[key] = monotime.Time(monoTime)
|
||||||
}
|
}
|
||||||
return activities
|
return activities
|
||||||
}
|
}
|
||||||
@@ -51,18 +51,20 @@ func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPor
|
|||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
if pr, exists := r.peers[publicKey]; exists {
|
var record *PeerRecord
|
||||||
delete(r.addrToPeer, pr.Address)
|
record, exists := r.peers[publicKey]
|
||||||
pr.Address = address
|
if exists {
|
||||||
|
delete(r.addrToPeer, record.Address)
|
||||||
|
record.Address = address
|
||||||
} else {
|
} else {
|
||||||
record := &PeerRecord{
|
record = &PeerRecord{
|
||||||
Address: address,
|
Address: address,
|
||||||
}
|
}
|
||||||
record.LastActivity.Store(monotime.Now())
|
record.LastActivity.Store(int64(monotime.Now()))
|
||||||
r.peers[publicKey] = record
|
r.peers[publicKey] = record
|
||||||
}
|
}
|
||||||
|
|
||||||
r.addrToPeer[address] = r.peers[publicKey]
|
r.addrToPeer[address] = record
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ActivityRecorder) Remove(publicKey string) {
|
func (r *ActivityRecorder) Remove(publicKey string) {
|
||||||
@@ -84,7 +86,7 @@ func (r *ActivityRecorder) record(address netip.AddrPort) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := monotime.Now()
|
now := int64(monotime.Now())
|
||||||
last := record.LastActivity.Load()
|
last := record.LastActivity.Load()
|
||||||
if now-last < saveFrequency {
|
if now-last < saveFrequency {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
||||||
@@ -17,11 +19,7 @@ func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
|||||||
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.IsZero() {
|
if monotime.Since(p) > 5*time.Second {
|
||||||
t.Fatalf("Expected activity for peer %s, but got zero", peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.Before(time.Now().Add(-2 * time.Minute)) {
|
|
||||||
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
15
client/iface/bind/control.go
Normal file
15
client/iface/bind/control.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||||
|
func init() {
|
||||||
|
listener := nbnet.NewListener()
|
||||||
|
if listener.ListenConfig.Control != nil {
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
package bind
|
|
||||||
|
|
||||||
import (
|
|
||||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// ControlFns is not thread safe and should only be modified during init.
|
|
||||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
|
||||||
}
|
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
type RecvMessage struct {
|
||||||
@@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
|
|
||||||
s.udpMux = NewUniversalUDPMuxDefault(
|
s.udpMux = NewUniversalUDPMuxDefault(
|
||||||
UniversalUDPMuxParams{
|
UniversalUDPMuxParams{
|
||||||
UDPConn: conn,
|
UDPConn: nbnet.WrapUDPConn(conn),
|
||||||
Net: s.transportNet,
|
Net: s.transportNet,
|
||||||
FilterFn: s.filterFn,
|
FilterFn: s.filterFn,
|
||||||
WGAddress: s.address,
|
WGAddress: s.address,
|
||||||
|
|||||||
@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addressMapMu.Lock()
|
var allAddresses []string
|
||||||
defer m.addressMapMu.Unlock()
|
|
||||||
|
|
||||||
for _, c := range removedConns {
|
for _, c := range removedConns {
|
||||||
addresses := c.getAddresses()
|
addresses := c.getAddresses()
|
||||||
for _, addr := range addresses {
|
allAddresses = append(allAddresses, addresses...)
|
||||||
delete(m.addressMap, addr)
|
}
|
||||||
}
|
|
||||||
|
m.addressMapMu.Lock()
|
||||||
|
for _, addr := range allAddresses {
|
||||||
|
delete(m.addressMap, addr)
|
||||||
|
}
|
||||||
|
m.addressMapMu.Unlock()
|
||||||
|
|
||||||
|
for _, addr := range allAddresses {
|
||||||
|
m.notifyAddressRemoval(addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.addressMapMu.Lock()
|
m.addressMapMu.Lock()
|
||||||
defer m.addressMapMu.Unlock()
|
|
||||||
|
|
||||||
existing, ok := m.addressMap[addr]
|
existing, ok := m.addressMap[addr]
|
||||||
if !ok {
|
if !ok {
|
||||||
existing = []*udpMuxedConn{}
|
existing = []*udpMuxedConn{}
|
||||||
}
|
}
|
||||||
existing = append(existing, conn)
|
existing = append(existing, conn)
|
||||||
m.addressMap[addr] = existing
|
m.addressMap[addr] = existing
|
||||||
|
m.addressMapMu.Unlock()
|
||||||
|
|
||||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||||
}
|
}
|
||||||
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
|
|||||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||||
// We will then forward STUN packets to each of these connections.
|
// We will then forward STUN packets to each of these connections.
|
||||||
m.addressMapMu.Lock()
|
m.addressMapMu.RLock()
|
||||||
var destinationConnList []*udpMuxedConn
|
var destinationConnList []*udpMuxedConn
|
||||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||||
destinationConnList = append(destinationConnList, storedConns...)
|
destinationConnList = append(destinationConnList, storedConns...)
|
||||||
}
|
}
|
||||||
m.addressMapMu.Unlock()
|
m.addressMapMu.RUnlock()
|
||||||
|
|
||||||
var isIPv6 bool
|
var isIPv6 bool
|
||||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||||
|
|||||||
21
client/iface/bind/udp_mux_generic.go
Normal file
21
client/iface/bind/udp_mux_generic.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||||
|
wrapped, ok := m.params.UDPConn.(*UDPConn)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nbnetConn.RemoveAddress(addr)
|
||||||
|
}
|
||||||
7
client/iface/bind/udp_mux_ios.go
Normal file
7
client/iface/bind/udp_mux_ios.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||||
|
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||||
|
}
|
||||||
@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
|
|
||||||
// wrap UDP connection, process server reflexive messages
|
// wrap UDP connection, process server reflexive messages
|
||||||
// before they are passed to the UDPMux connection handler (connWorker)
|
// before they are passed to the UDPMux connection handler (connWorker)
|
||||||
m.params.UDPConn = &udpConn{
|
m.params.UDPConn = &UDPConn{
|
||||||
PacketConn: params.UDPConn,
|
PacketConn: params.UDPConn,
|
||||||
mux: m,
|
mux: m,
|
||||||
logger: params.Logger,
|
logger: params.Logger,
|
||||||
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
address: params.WGAddress,
|
address: params.WGAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
// embed UDPMux
|
|
||||||
udpMuxParams := UDPMuxParams{
|
udpMuxParams := UDPMuxParams{
|
||||||
Logger: params.Logger,
|
Logger: params.Logger,
|
||||||
UDPConn: m.params.UDPConn,
|
UDPConn: m.params.UDPConn,
|
||||||
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||||
type udpConn struct {
|
type UDPConn struct {
|
||||||
net.PacketConn
|
net.PacketConn
|
||||||
mux *UniversalUDPMuxDefault
|
mux *UniversalUDPMuxDefault
|
||||||
logger logging.LeveledLogger
|
logger logging.LeveledLogger
|
||||||
@@ -125,7 +124,12 @@ type udpConn struct {
|
|||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
// GetPacketConn returns the underlying PacketConn
|
||||||
|
func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||||
|
return u.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
if u.filterFn == nil {
|
if u.filterFn == nil {
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
|||||||
return u.handleUncachedAddress(b, addr)
|
return u.handleUncachedAddress(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||||
if isRouted {
|
if isRouted {
|
||||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||||
}
|
}
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||||
if err := u.performFilterCheck(addr); err != nil {
|
if err := u.performFilterCheck(addr); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||||
host, err := getHostFromAddr(addr)
|
host, err := getHostFromAddr(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
var zeroKey wgtypes.Key
|
var zeroKey wgtypes.Key
|
||||||
@@ -277,6 +279,6 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) LastActivities() map[string]time.Time {
|
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,4 +3,4 @@
|
|||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
// WgInterfaceDefault is a default interface name of Netbird
|
// WgInterfaceDefault is a default interface name of Netbird
|
||||||
const WgInterfaceDefault = "wt0"
|
const WgInterfaceDefault = "nb0"
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,7 +224,7 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
|||||||
return parseStatus(c.deviceName, ipcStr)
|
return parseStatus(c.deviceName, ipcStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) LastActivities() map[string]time.Time {
|
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
|
||||||
return c.activityRecorder.GetLastActivities()
|
return c.activityRecorder.GetLastActivities()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGConfigurer interface {
|
type WGConfigurer interface {
|
||||||
@@ -19,5 +20,5 @@ type WGConfigurer interface {
|
|||||||
Close()
|
Close()
|
||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
LastActivities() map[string]time.Time
|
LastActivities() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -237,7 +238,7 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
|||||||
return w.configurer.GetStats()
|
return w.configurer.GetStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WGIface) LastActivities() map[string]time.Time {
|
func (w *WGIface) LastActivities() map[string]monotime.Time {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -41,9 +41,12 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
}
|
}
|
||||||
t.tundev = nsTunDev
|
t.tundev = nsTunDev
|
||||||
|
|
||||||
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
var skipProxy bool
|
||||||
if err != nil {
|
if val := os.Getenv(EnvSkipProxy); val != "" {
|
||||||
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
skipProxy, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if skipProxy {
|
if skipProxy {
|
||||||
return nsTunDev, tunNet, nil
|
return nsTunDev, tunNet, nil
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxyBind struct {
|
type ProxyBind struct {
|
||||||
@@ -28,6 +29,17 @@ type ProxyBind struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
|
closeListener *listener.CloseListener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||||
|
p := &ProxyBind{
|
||||||
|
Bind: bind,
|
||||||
|
closeListener: listener.NewCloseListener(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn adds a new connection to the bind.
|
// AddTurnConn adds a new connection to the bind.
|
||||||
@@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||||
|
p.closeListener.SetCloseListener(disconnected)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) Work() {
|
func (p *ProxyBind) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
@@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
|
|||||||
if p.closed {
|
if p.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.closeListener.SetCloseListener(nil)
|
||||||
|
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
@@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
p.closeListener.Notify()
|
||||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
@@ -26,6 +28,15 @@ type ProxyWrapper struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
|
closeListener *listener.CloseListener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
||||||
|
return &ProxyWrapper{
|
||||||
|
WgeBPFProxy: WgeBPFProxy,
|
||||||
|
closeListener: listener.NewCloseListener(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
@@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
|||||||
return p.wgEndpointAddr
|
return p.wgEndpointAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||||
|
p.closeListener.SetCloseListener(disconnected)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) Work() {
|
func (p *ProxyWrapper) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
@@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
|
|
||||||
e.cancel()
|
e.cancel()
|
||||||
|
|
||||||
|
e.closeListener.SetCloseListener(nil)
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
@@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return 0, ctx.Err()
|
return 0, ctx.Err()
|
||||||
}
|
}
|
||||||
|
p.closeListener.Notify()
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
|
|||||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ebpf.ProxyWrapper{
|
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||||
WgeBPFProxy: w.ebpfProxy,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
func (w *KernelFactory) Free() error {
|
||||||
|
|||||||
@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) GetProxy() Proxy {
|
func (w *USPFactory) GetProxy() Proxy {
|
||||||
return &proxyBind.ProxyBind{
|
return proxyBind.NewProxyBind(w.bind)
|
||||||
Bind: w.bind,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) Free() error {
|
func (w *USPFactory) Free() error {
|
||||||
|
|||||||
19
client/iface/wgproxy/listener/listener.go
Normal file
19
client/iface/wgproxy/listener/listener.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package listener
|
||||||
|
|
||||||
|
type CloseListener struct {
|
||||||
|
listener func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCloseListener() *CloseListener {
|
||||||
|
return &CloseListener{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CloseListener) SetCloseListener(listener func()) {
|
||||||
|
c.listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CloseListener) Notify() {
|
||||||
|
if c.listener != nil {
|
||||||
|
c.listener()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,4 +12,5 @@ type Proxy interface {
|
|||||||
Work() // Work start or resume the proxy
|
Work() // Work start or resume the proxy
|
||||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
|
SetDisconnectListener(disconnected func())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
proxyWrapper := &ebpf.ProxyWrapper{
|
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
WgeBPFProxy: ebpfProxy,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests = append(tests, struct {
|
tests = append(tests, struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
cerrors "github.com/netbirdio/netbird/client/errors"
|
cerrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGUDPProxy proxies
|
// WGUDPProxy proxies
|
||||||
@@ -28,6 +29,8 @@ type WGUDPProxy struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
|
closeListener *listener.CloseListener
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
||||||
@@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
|||||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||||
p := &WGUDPProxy{
|
p := &WGUDPProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
@@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
|
|||||||
return endpointUdpAddr
|
return endpointUdpAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
|
||||||
|
p.closeListener.SetCloseListener(disconnected)
|
||||||
|
}
|
||||||
|
|
||||||
// Work starts the proxy or resumes it if it was paused
|
// Work starts the proxy or resumes it if it was paused
|
||||||
func (p *WGUDPProxy) Work() {
|
func (p *WGUDPProxy) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
@@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
|
|||||||
if p.closed {
|
if p.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.closeListener.SetCloseListener(nil)
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
@@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
p.closeListener.Notify()
|
||||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var defaultInterfaceBlacklist = []string{
|
var defaultInterfaceBlacklist = []string{
|
||||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -226,7 +226,6 @@ func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
|
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
|
||||||
conn.Log.Infof("activated peer from inactive state")
|
|
||||||
if err := conn.Open(ctx); err != nil {
|
if err := conn.Open(ctx); err != nil {
|
||||||
conn.Log.Errorf("failed to open connection: %v", err)
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ import (
|
|||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||||
@@ -138,9 +137,6 @@ type Engine struct {
|
|||||||
|
|
||||||
connMgr *ConnMgr
|
connMgr *ConnMgr
|
||||||
|
|
||||||
beforePeerHook nbnet.AddHookFunc
|
|
||||||
afterPeerHook nbnet.RemoveHookFunc
|
|
||||||
|
|
||||||
// rpManager is a Rosenpass manager
|
// rpManager is a Rosenpass manager
|
||||||
rpManager *rosenpass.Manager
|
rpManager *rosenpass.Manager
|
||||||
|
|
||||||
@@ -409,12 +405,8 @@ func (e *Engine) Start() error {
|
|||||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||||
})
|
})
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
if err := e.routeManager.Init(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
} else {
|
|
||||||
e.beforePeerHook = beforePeerHook
|
|
||||||
e.afterPeerHook = afterPeerHook
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
@@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
|
||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
|
||||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
@@ -96,7 +97,7 @@ type MockWGIface struct {
|
|||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
GetNetFunc func() *netstack.Net
|
GetNetFunc func() *netstack.Net
|
||||||
LastActivitiesFunc func() map[string]time.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||||
@@ -187,7 +188,7 @@ func (m *MockWGIface) GetNet() *netstack.Net {
|
|||||||
return m.GetNetFunc()
|
return m.GetNetFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) LastActivities() map[string]time.Time {
|
func (m *MockWGIface) LastActivities() map[string]monotime.Time {
|
||||||
if m.LastActivitiesFunc != nil {
|
if m.LastActivitiesFunc != nil {
|
||||||
return m.LastActivitiesFunc()
|
return m.LastActivitiesFunc()
|
||||||
}
|
}
|
||||||
@@ -399,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
StatusRecorder: engine.statusRecorder,
|
StatusRecorder: engine.statusRecorder,
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
})
|
})
|
||||||
_, _, err = engine.routeManager.Init()
|
err = engine.routeManager.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@@ -1392,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
if runtime.GOOS == "darwin" {
|
if runtime.GOOS == "darwin" {
|
||||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||||
} else {
|
} else {
|
||||||
ifaceName = fmt.Sprintf("wt%d", i)
|
ifaceName = fmt.Sprintf("nb%d", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
wgPort := 33100 + i
|
wgPort := 33100 + i
|
||||||
@@ -1480,6 +1481,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.ExtraSettings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
|
||||||
@@ -1489,7 +1494,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgIfaceBase interface {
|
type wgIfaceBase interface {
|
||||||
@@ -38,5 +39,5 @@ type wgIfaceBase interface {
|
|||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
LastActivities() map[string]time.Time
|
LastActivities() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (d *Listener) ReadPackets() {
|
|||||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if d.isClosed.Load() {
|
if d.isClosed.Load() {
|
||||||
d.peerCfg.Log.Debugf("exit from activity listener")
|
d.peerCfg.Log.Infof("exit from activity listener")
|
||||||
} else {
|
} else {
|
||||||
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
|
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
|
||||||
}
|
}
|
||||||
@@ -59,9 +59,11 @@ func (d *Listener) ReadPackets() {
|
|||||||
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
d.peerCfg.Log.Infof("activity detected")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||||
if err := d.removeEndpoint(); err != nil {
|
if err := d.removeEndpoint(); err != nil {
|
||||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||||
}
|
}
|
||||||
@@ -71,7 +73,7 @@ func (d *Listener) ReadPackets() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Listener) Close() {
|
func (d *Listener) Close() {
|
||||||
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
|
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||||
d.isClosed.Store(true)
|
d.isClosed.Store(true)
|
||||||
|
|
||||||
if err := d.conn.Close(); err != nil {
|
if err := d.conn.Close(); err != nil {
|
||||||
@@ -81,7 +83,6 @@ func (d *Listener) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Listener) removeEndpoint() error {
|
func (d *Listener) removeEndpoint() error {
|
||||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
|
||||||
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -18,7 +19,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type WgInterface interface {
|
type WgInterface interface {
|
||||||
LastActivities() map[string]time.Time
|
LastActivities() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@@ -124,6 +125,7 @@ func (m *Manager) checkStats() (map[string]struct{}, error) {
|
|||||||
|
|
||||||
idlePeers := make(map[string]struct{})
|
idlePeers := make(map[string]struct{})
|
||||||
|
|
||||||
|
checkTime := time.Now()
|
||||||
for peerID, peerCfg := range m.interestedPeers {
|
for peerID, peerCfg := range m.interestedPeers {
|
||||||
lastActive, ok := lastActivities[peerID]
|
lastActive, ok := lastActivities[peerID]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -132,8 +134,9 @@ func (m *Manager) checkStats() (map[string]struct{}, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Since(lastActive) > m.inactivityThreshold {
|
since := monotime.Since(lastActive)
|
||||||
peerCfg.Log.Infof("peer is inactive since: %v", lastActive)
|
if since > m.inactivityThreshold {
|
||||||
|
peerCfg.Log.Infof("peer is inactive since time: %s", checkTime.Add(-since).String())
|
||||||
idlePeers[peerID] = struct{}{}
|
idlePeers[peerID] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,13 +9,14 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockWgInterface struct {
|
type mockWgInterface struct {
|
||||||
lastActivities map[string]time.Time
|
lastActivities map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockWgInterface) LastActivities() map[string]time.Time {
|
func (m *mockWgInterface) LastActivities() map[string]monotime.Time {
|
||||||
return m.lastActivities
|
return m.lastActivities
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,8 +24,8 @@ func TestPeerTriggersInactivity(t *testing.T) {
|
|||||||
peerID := "peer1"
|
peerID := "peer1"
|
||||||
|
|
||||||
wgMock := &mockWgInterface{
|
wgMock := &mockWgInterface{
|
||||||
lastActivities: map[string]time.Time{
|
lastActivities: map[string]monotime.Time{
|
||||||
peerID: time.Now().Add(-20 * time.Minute),
|
peerID: monotime.Time(int64(monotime.Now()) - int64(20*time.Minute)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,8 +65,8 @@ func TestPeerTriggersActivity(t *testing.T) {
|
|||||||
peerID := "peer1"
|
peerID := "peer1"
|
||||||
|
|
||||||
wgMock := &mockWgInterface{
|
wgMock := &mockWgInterface{
|
||||||
lastActivities: map[string]time.Time{
|
lastActivities: map[string]monotime.Time{
|
||||||
peerID: time.Now().Add(-5 * time.Minute),
|
peerID: monotime.Time(int64(monotime.Now()) - int64(5*time.Minute)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -258,12 +258,13 @@ func (m *Manager) ActivatePeer(peerID string) (found bool) {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cfg.Log.Infof("activate peer from inactive state by remote signal message")
|
||||||
|
|
||||||
if !m.activateSinglePeer(cfg, mp) {
|
if !m.activateSinglePeer(cfg, mp) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
m.activateHAGroupPeers(cfg)
|
m.activateHAGroupPeers(cfg)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -571,12 +572,12 @@ func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
|||||||
// this is blocking operation, potentially can be optimized
|
// this is blocking operation, potentially can be optimized
|
||||||
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
|
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
|
||||||
|
|
||||||
mp.peerCfg.Log.Infof("start activity monitor")
|
|
||||||
|
|
||||||
mp.expectedWatcher = watcherActivity
|
mp.expectedWatcher = watcherActivity
|
||||||
|
|
||||||
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("start activity monitor")
|
||||||
|
|
||||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGIface interface {
|
type WGIface interface {
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
LastActivities() map[string]time.Time
|
LastActivities() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockIFaceMapper) Name() string {
|
func (m *mockIFaceMapper) Name() string {
|
||||||
return "wt0"
|
return "nb0"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -106,10 +105,6 @@ type Conn struct {
|
|||||||
workerRelay *WorkerRelay
|
workerRelay *WorkerRelay
|
||||||
wgWatcherWg sync.WaitGroup
|
wgWatcherWg sync.WaitGroup
|
||||||
|
|
||||||
connIDRelay nbnet.ConnectionID
|
|
||||||
connIDICE nbnet.ConnectionID
|
|
||||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
|
||||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
|
||||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||||
rosenpassRemoteKey []byte
|
rosenpassRemoteKey []byte
|
||||||
|
|
||||||
@@ -167,7 +162,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
|
|
||||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||||
|
|
||||||
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||||
|
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
@@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) {
|
|||||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.freeUpConnID()
|
|
||||||
|
|
||||||
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
|
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
|
||||||
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
|
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
|
||||||
}
|
}
|
||||||
@@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
|
|||||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
|
|
||||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
|
||||||
}
|
|
||||||
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
|
|
||||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||||
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
|
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
|
||||||
conn.onConnected = handler
|
conn.onConnected = handler
|
||||||
@@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
ep = directEp
|
ep = directEp
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
|
||||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
conn.workerRelay.DisableWgWatcher()
|
||||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||||
|
|
||||||
@@ -489,6 +471,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
|
||||||
|
|
||||||
conn.dumpState.NewLocalProxy()
|
conn.dumpState.NewLocalProxy()
|
||||||
|
|
||||||
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
@@ -501,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
|
||||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
@@ -705,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
|
||||||
conn.connIDICE = nbnet.GenerateConnID()
|
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
|
||||||
if err := hook(conn.connIDICE, ip); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) freeUpConnID() {
|
|
||||||
if conn.connIDRelay != "" {
|
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
|
||||||
if err := hook(conn.connIDRelay); err != nil {
|
|
||||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.connIDRelay = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if conn.connIDICE != "" {
|
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
|
||||||
if err := hook(conn.connIDICE); err != nil {
|
|
||||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.connIDICE = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||||
udpAddr := &net.UDPAddr{
|
udpAddr := &net.UDPAddr{
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type RelayConnInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WorkerRelay struct {
|
type WorkerRelay struct {
|
||||||
|
peerCtx context.Context
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
isController bool
|
isController bool
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
@@ -33,8 +34,9 @@ type WorkerRelay struct {
|
|||||||
wgWatcher *WGWatcher
|
wgWatcher *WGWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
||||||
r := &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
|
peerCtx: ctx,
|
||||||
log: log,
|
log: log,
|
||||||
isController: ctrl,
|
isController: ctrl,
|
||||||
config: config,
|
config: config,
|
||||||
@@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
|
|
||||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||||
|
|
||||||
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
|
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||||
|
|||||||
@@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := common.HandlerParams{
|
params := common.HandlerParams{
|
||||||
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||||
}
|
}
|
||||||
// create new clientNetwork
|
// create new clientNetwork
|
||||||
client := &Watcher{
|
client := &Watcher{
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ import (
|
|||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() error
|
||||||
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||||
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
@@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Init sets up the routing
|
// Init sets up the routing
|
||||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (m *DefaultManager) Init() error {
|
||||||
m.routeSelector = m.initSelector()
|
m.routeSelector = m.initSelector()
|
||||||
|
|
||||||
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||||
@@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
|||||||
|
|
||||||
ips := resolveURLsToIPs(initialAddresses)
|
ips := resolveURLsToIPs(initialAddresses)
|
||||||
|
|
||||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
|
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("setup routing: %w", err)
|
||||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Routing setup complete")
|
log.Info("Routing setup complete")
|
||||||
return beforePeerHook, afterPeerHook, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||||
|
|||||||
@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
StatusRecorder: statusRecorder,
|
StatusRecorder: statusRecorder,
|
||||||
})
|
})
|
||||||
|
|
||||||
_, _, err = routeManager.Init()
|
err = routeManager.Init()
|
||||||
|
|
||||||
require.NoError(t, err, "should init route manager")
|
require.NoError(t, err, "should init route manager")
|
||||||
defer routeManager.Stop(nil)
|
defer routeManager.Stop(nil)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
@@ -23,8 +22,8 @@ type MockManager struct {
|
|||||||
StopFunc func(manager *statemanager.Manager)
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
func (m *MockManager) Init() error {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
@@ -56,6 +57,10 @@ type SysOps struct {
|
|||||||
// seq is an atomic counter for generating unique sequence numbers for route messages
|
// seq is an atomic counter for generating unique sequence numbers for route messages
|
||||||
//nolint:unused // only used on BSD systems
|
//nolint:unused // only used on BSD systems
|
||||||
seq atomic.Uint32
|
seq atomic.Uint32
|
||||||
|
|
||||||
|
localSubnetsCache []*net.IPNet
|
||||||
|
localSubnetsCacheMu sync.RWMutex
|
||||||
|
localSubnetsCacheTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||||
|
|||||||
@@ -10,11 +10,10 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/libp2p/go-netroute"
|
"github.com/libp2p/go-netroute"
|
||||||
@@ -24,6 +25,8 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const localSubnetsCacheTTL = 15 * time.Minute
|
||||||
|
|
||||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
@@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
|||||||
|
|
||||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||||
|
|
||||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
@@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
|||||||
|
|
||||||
r.refCounter = refCounter
|
r.refCounter = refCounter
|
||||||
|
|
||||||
return r.setupHooks(initAddresses, stateManager)
|
if err := r.setupHooks(initAddresses, stateManager); err != nil {
|
||||||
|
return fmt.Errorf("setup hooks: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates state on every change so it will be persisted regularly
|
// updateState updates state on every change so it will be persisted regularly
|
||||||
@@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
|||||||
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
|
||||||
exitNextHop := Nexthop{
|
exitNextHop := nexthop
|
||||||
IP: nexthop.IP,
|
|
||||||
Intf: nexthop.Intf,
|
|
||||||
}
|
|
||||||
|
|
||||||
vpnAddr := vpnIntf.Address().IP
|
vpnAddr := vpnIntf.Address().IP
|
||||||
|
|
||||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||||
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
||||||
|
|
||||||
exitNextHop = initialNextHop
|
exitNextHop = initialNextHop
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
||||||
|
r.localSubnetsCacheMu.RLock()
|
||||||
|
cacheAge := time.Since(r.localSubnetsCacheTime)
|
||||||
|
subnets := r.localSubnetsCache
|
||||||
|
r.localSubnetsCacheMu.RUnlock()
|
||||||
|
|
||||||
|
if cacheAge > localSubnetsCacheTTL || subnets == nil {
|
||||||
|
r.localSubnetsCacheMu.Lock()
|
||||||
|
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
|
||||||
|
r.refreshLocalSubnetsCache()
|
||||||
|
}
|
||||||
|
subnets = r.localSubnetsCache
|
||||||
|
r.localSubnetsCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
if subnet.Contains(prefix.Addr().AsSlice()) {
|
||||||
|
return true, subnet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) refreshLocalSubnetsCache() {
|
||||||
localInterfaces, err := net.Interfaces()
|
localInterfaces, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get local interfaces: %v", err)
|
log.Errorf("Failed to get local interfaces: %v", err)
|
||||||
return false, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var newSubnets []*net.IPNet
|
||||||
for _, intf := range localInterfaces {
|
for _, intf := range localInterfaces {
|
||||||
addrs, err := intf.Addrs()
|
addrs, err := intf.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
|
|||||||
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
newSubnets = append(newSubnets, ipnet)
|
||||||
if ipnet.Contains(prefix.Addr().AsSlice()) {
|
|
||||||
return true, ipnet
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
r.localSubnetsCache = newSubnets
|
||||||
|
r.localSubnetsCacheTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||||
@@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
|||||||
return r.removeFromRouteTable(prefix, nextHop)
|
return r.removeFromRouteTable(prefix, nextHop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||||
prefix, err := util.GetPrefixFromIP(ip)
|
prefix, err := util.GetPrefixFromIP(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, ip := range initAddresses {
|
for _, ip := range initAddresses {
|
||||||
if err := beforeHook("init", ip); err != nil {
|
if err := beforeHook("init", ip); err != nil {
|
||||||
log.Errorf("Failed to add route reference: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
var result *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, ip := range resolvedIPs {
|
for _, ip := range resolvedIPs {
|
||||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
||||||
}
|
}
|
||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
})
|
})
|
||||||
|
|
||||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||||
@@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
return afterHook(connID)
|
return afterHook(connID)
|
||||||
})
|
})
|
||||||
|
|
||||||
return beforeHook, afterHook, nil
|
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||||
|
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||||
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState(stateManager)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
_, _, err := r.SetupRouting(nil, nil)
|
err := r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
_, _, err := r.SetupRouting(nil, nil)
|
err := r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
_, _, err := r.SetupRouting(nil, nil)
|
err := r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
|
|||||||
@@ -10,14 +10,13 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
r.prefixes = make(map[netip.Prefix]struct{})
|
r.prefixes = make(map[netip.Prefix]struct{})
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
|
|||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
||||||
if !nbnet.AdvancedRouting() {
|
if !nbnet.AdvancedRouting() {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := addRule(rule); err != nil {
|
if err := addRule(rule); err != nil {
|
||||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
return fmt.Errorf("%s: %w", rule.description, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
}
|
}
|
||||||
originalSysctl = originalValues
|
originalSysctl = originalValues
|
||||||
|
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
|
|||||||
IP: wgNetwork.Addr(),
|
IP: wgNetwork.Addr(),
|
||||||
Network: wgNetwork,
|
Network: wgNetwork,
|
||||||
},
|
},
|
||||||
name: "wt0",
|
name: "nb0",
|
||||||
}
|
}
|
||||||
|
|
||||||
sysOps := &SysOps{
|
sysOps := &SysOps{
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const InfiniteLifetime = 0xffffffff
|
const InfiniteLifetime = 0xffffffff
|
||||||
@@ -137,7 +136,7 @@ const (
|
|||||||
RouteDeleted
|
RouteDeleted
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1330,6 +1330,13 @@ func (x *PeerState) GetRelayAddress() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *PeerState) GetConnectionType() string {
|
||||||
|
if x.Relayed {
|
||||||
|
return "Relayed"
|
||||||
|
}
|
||||||
|
return "P2P"
|
||||||
|
}
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
// LocalPeerState contains the latest state of the local peer
|
||||||
type LocalPeerState struct {
|
type LocalPeerState struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ type OutputOverview struct {
|
|||||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
|
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview {
|
||||||
pbFullStatus := resp.GetFullStatus()
|
pbFullStatus := resp.GetFullStatus()
|
||||||
|
|
||||||
managementState := pbFullStatus.GetManagementState()
|
managementState := pbFullStatus.GetManagementState()
|
||||||
@@ -118,7 +118,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter)
|
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
|
||||||
|
|
||||||
overview := OutputOverview{
|
overview := OutputOverview{
|
||||||
Peers: peersOverview,
|
Peers: peersOverview,
|
||||||
@@ -193,6 +193,7 @@ func mapPeers(
|
|||||||
prefixNamesFilter []string,
|
prefixNamesFilter []string,
|
||||||
prefixNamesFilterMap map[string]struct{},
|
prefixNamesFilterMap map[string]struct{},
|
||||||
ipsFilter map[string]struct{},
|
ipsFilter map[string]struct{},
|
||||||
|
connectionTypeFilter string,
|
||||||
) PeersStateOutput {
|
) PeersStateOutput {
|
||||||
var peersStateDetail []PeerStateDetailOutput
|
var peersStateDetail []PeerStateDetailOutput
|
||||||
peersConnected := 0
|
peersConnected := 0
|
||||||
@@ -208,7 +209,7 @@ func mapPeers(
|
|||||||
transferSent := int64(0)
|
transferSent := int64(0)
|
||||||
|
|
||||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
|
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if isPeerConnected {
|
if isPeerConnected {
|
||||||
@@ -218,10 +219,7 @@ func mapPeers(
|
|||||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||||
connType = "P2P"
|
connType = pbPeerState.GetConnectionType()
|
||||||
if pbPeerState.Relayed {
|
|
||||||
connType = "Relayed"
|
|
||||||
}
|
|
||||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||||
transferReceived = pbPeerState.GetBytesRx()
|
transferReceived = pbPeerState.GetBytesRx()
|
||||||
@@ -542,10 +540,11 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
return peersString
|
return peersString
|
||||||
}
|
}
|
||||||
|
|
||||||
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool {
|
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool {
|
||||||
statusEval := false
|
statusEval := false
|
||||||
ipEval := false
|
ipEval := false
|
||||||
nameEval := true
|
nameEval := true
|
||||||
|
connectionTypeEval := false
|
||||||
|
|
||||||
if statusFilter != "" {
|
if statusFilter != "" {
|
||||||
if !strings.EqualFold(peerStatus, statusFilter) {
|
if !strings.EqualFold(peerStatus, statusFilter) {
|
||||||
@@ -570,8 +569,11 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi
|
|||||||
} else {
|
} else {
|
||||||
nameEval = false
|
nameEval = false
|
||||||
}
|
}
|
||||||
|
if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) {
|
||||||
|
connectionTypeEval = true
|
||||||
|
}
|
||||||
|
|
||||||
return statusEval || ipEval || nameEval
|
return statusEval || ipEval || nameEval || connectionTypeEval
|
||||||
}
|
}
|
||||||
|
|
||||||
func toIEC(b int64) string {
|
func toIEC(b int64) string {
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ var overview = OutputOverview{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||||
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil)
|
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "")
|
||||||
|
|
||||||
assert.Equal(t, overview, convertedResult)
|
assert.Equal(t, overview, convertedResult)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
|
|
||||||
var postUpStatusOutput string
|
var postUpStatusOutput string
|
||||||
if postUpStatus != nil {
|
if postUpStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil)
|
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "")
|
||||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
@@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
|
|
||||||
var preDownStatusOutput string
|
var preDownStatusOutput string
|
||||||
if preDownStatus != nil {
|
if preDownStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil)
|
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "")
|
||||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||||
@@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
|
|
||||||
var statusOutput string
|
var statusOutput string
|
||||||
if statusResp != nil {
|
if statusResp != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil)
|
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "")
|
||||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -63,7 +63,7 @@ require (
|
|||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ var (
|
|||||||
ephemeralManager.LoadInitialPeers(ctx)
|
ephemeralManager.LoadInitialPeers(ctx)
|
||||||
|
|
||||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||||
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager)
|
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2887,7 +2887,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
|||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
|
||||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
@@ -40,13 +41,14 @@ type GRPCServer struct {
|
|||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
wgKey wgtypes.Key
|
wgKey wgtypes.Key
|
||||||
proto.UnimplementedManagementServiceServer
|
proto.UnimplementedManagementServiceServer
|
||||||
peersUpdateManager *PeersUpdateManager
|
peersUpdateManager *PeersUpdateManager
|
||||||
config *types.Config
|
config *types.Config
|
||||||
secretsManager SecretsManager
|
secretsManager SecretsManager
|
||||||
appMetrics telemetry.AppMetrics
|
appMetrics telemetry.AppMetrics
|
||||||
ephemeralManager *EphemeralManager
|
ephemeralManager *EphemeralManager
|
||||||
peerLocks sync.Map
|
peerLocks sync.Map
|
||||||
authManager auth.Manager
|
authManager auth.Manager
|
||||||
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
@@ -60,6 +62,7 @@ func NewServer(
|
|||||||
appMetrics telemetry.AppMetrics,
|
appMetrics telemetry.AppMetrics,
|
||||||
ephemeralManager *EphemeralManager,
|
ephemeralManager *EphemeralManager,
|
||||||
authManager auth.Manager,
|
authManager auth.Manager,
|
||||||
|
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||||
) (*GRPCServer, error) {
|
) (*GRPCServer, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,14 +82,15 @@ func NewServer(
|
|||||||
return &GRPCServer{
|
return &GRPCServer{
|
||||||
wgKey: key,
|
wgKey: key,
|
||||||
// peerKey -> event channel
|
// peerKey -> event channel
|
||||||
peersUpdateManager: peersUpdateManager,
|
peersUpdateManager: peersUpdateManager,
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
settingsManager: settingsManager,
|
settingsManager: settingsManager,
|
||||||
config: config,
|
config: config,
|
||||||
secretsManager: secretsManager,
|
secretsManager: secretsManager,
|
||||||
authManager: authManager,
|
authManager: authManager,
|
||||||
appMetrics: appMetrics,
|
appMetrics: appMetrics,
|
||||||
ephemeralManager: ephemeralManager,
|
ephemeralManager: ephemeralManager,
|
||||||
|
integratedPeerValidator: integratedPeerValidator,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -850,7 +854,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
|
|||||||
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
|
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfoResp := &proto.PKCEAuthorizationFlow{
|
initInfoFlow := &proto.PKCEAuthorizationFlow{
|
||||||
ProviderConfig: &proto.ProviderConfig{
|
ProviderConfig: &proto.ProviderConfig{
|
||||||
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
|
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
|
||||||
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
|
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
|
||||||
@@ -865,6 +869,8 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
||||||
|
|||||||
@@ -424,9 +424,10 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
|||||||
}
|
}
|
||||||
if group, ok := groupsMap[gid]; ok {
|
if group, ok := groupsMap[gid]; ok {
|
||||||
minimum := api.GroupMinimum{
|
minimum := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
PeersCount: len(group.Peers),
|
PeersCount: len(group.Peers),
|
||||||
|
ResourcesCount: len(group.Resources),
|
||||||
}
|
}
|
||||||
destinations = append(destinations, minimum)
|
destinations = append(destinations, minimum)
|
||||||
cache[gid] = minimum
|
cache[gid] = minimum
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
package testing_tools
|
package testing_tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
@@ -132,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
|||||||
}
|
}
|
||||||
|
|
||||||
geoMock := &geolocation.Mock{}
|
geoMock := &geolocation.Mock{}
|
||||||
validatorMock := server.MocIntegratedValidator{}
|
validatorMock := server.MockIntegratedValidator{}
|
||||||
proxyController := integrations.NewController(store)
|
proxyController := integrations.NewController(store)
|
||||||
userManager := users.NewManager(store)
|
userManager := users.NewManager(store)
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -101,22 +102,23 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
|||||||
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
|
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MocIntegratedValidator struct {
|
type MockIntegratedValidator struct {
|
||||||
|
integrated_validator.IntegratedValidator
|
||||||
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||||
if a.ValidatePeerFunc != nil {
|
if a.ValidatePeerFunc != nil {
|
||||||
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
||||||
}
|
}
|
||||||
return update, false, nil
|
return update, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
|
func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
|
||||||
validatedPeers := make(map[string]struct{})
|
validatedPeers := make(map[string]struct{})
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
validatedPeers[peer.ID] = struct{}{}
|
validatedPeers[peer.ID] = struct{}{}
|
||||||
@@ -124,22 +126,22 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*ty
|
|||||||
return validatedPeers, nil
|
return validatedPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
|
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
|
func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
|
||||||
return false, false, nil
|
return false, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
|
func (MockIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
|
||||||
// just a dummy
|
// just a dummy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MocIntegratedValidator) Stop(_ context.Context) {
|
func (MockIntegratedValidator) Stop(_ context.Context) {
|
||||||
// just a dummy
|
// just a dummy
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package integrated_validator
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
@@ -17,4 +18,5 @@ type IntegratedValidator interface {
|
|||||||
PeerDeleted(ctx context.Context, accountID, peerID string) error
|
PeerDeleted(ctx context.Context, accountID, peerID string) error
|
||||||
SetPeerInvalidationListener(fn func(accountID string))
|
SetPeerInvalidationListener(fn func(accountID string))
|
||||||
Stop(ctx context.Context)
|
Stop(ctx context.Context)
|
||||||
|
ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -448,7 +448,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
|
||||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanup()
|
cleanup()
|
||||||
@@ -458,7 +458,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
|||||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
|
|
||||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil)
|
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", cleanup, err
|
return nil, nil, "", cleanup, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ func startServer(
|
|||||||
eventStore,
|
eventStore,
|
||||||
nil,
|
nil,
|
||||||
false,
|
false,
|
||||||
server.MocIntegratedValidator{},
|
server.MockIntegratedValidator{},
|
||||||
metrics,
|
metrics,
|
||||||
port_forwarding.NewControllerMock(),
|
port_forwarding.NewControllerMock(),
|
||||||
settingsMockManager,
|
settingsMockManager,
|
||||||
@@ -227,6 +227,7 @@ func startServer(
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
server.MockIntegratedValidator{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed creating management server: %v", err)
|
t.Fatalf("failed creating management server: %v", err)
|
||||||
|
|||||||
@@ -283,7 +283,7 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil {
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil {
|
||||||
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
|
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,6 +377,11 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
|
|||||||
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
|
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
|
||||||
var model T
|
var model T
|
||||||
|
|
||||||
|
if !db.Migrator().HasTable(&model) {
|
||||||
|
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
stmt := &gorm.Statement{DB: db}
|
stmt := &gorm.Statement{DB: db}
|
||||||
if err := stmt.Parse(&model); err != nil {
|
if err := stmt.Parse(&model); err != nil {
|
||||||
return fmt.Errorf("failed to parse model schema: %w", err)
|
return fmt.Errorf("failed to parse model schema: %w", err)
|
||||||
@@ -384,6 +389,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
|
|||||||
tableName := stmt.Schema.Table
|
tableName := stmt.Schema.Table
|
||||||
dialect := db.Dialector.Name()
|
dialect := db.Dialector.Name()
|
||||||
|
|
||||||
|
if db.Migrator().HasIndex(&model, indexName) {
|
||||||
|
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var columnClause string
|
var columnClause string
|
||||||
if dialect == "mysql" {
|
if dialect == "mysql" {
|
||||||
var withLength []string
|
var withLength []string
|
||||||
|
|||||||
@@ -4,16 +4,21 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/migration"
|
"github.com/netbirdio/netbird/management/server/migration"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/testutil"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -21,7 +26,41 @@ import (
|
|||||||
func setupDatabase(t *testing.T) *gorm.DB {
|
func setupDatabase(t *testing.T) *gorm.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
var db *gorm.DB
|
||||||
|
var err error
|
||||||
|
var dsn string
|
||||||
|
var cleanup func()
|
||||||
|
switch os.Getenv("NETBIRD_STORE_ENGINE") {
|
||||||
|
case "mysql":
|
||||||
|
cleanup, dsn, err = testutil.CreateMysqlTestContainer()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create MySQL test container: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dsn == "" {
|
||||||
|
t.Fatal("MySQL connection string is empty, ensure the test container is running")
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
|
||||||
|
case "postgres":
|
||||||
|
cleanup, dsn, err = testutil.CreatePostgresTestContainer()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create PostgreSQL test container: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dsn == "" {
|
||||||
|
t.Fatalf("PostgreSQL connection string is empty, ensure the test container is running")
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||||
|
case "sqlite":
|
||||||
|
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||||
|
default:
|
||||||
|
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "Failed to open database")
|
require.NoError(t, err, "Failed to open database")
|
||||||
return db
|
return db
|
||||||
@@ -34,6 +73,7 @@ func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
|
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
|
||||||
|
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||||
db := setupDatabase(t)
|
db := setupDatabase(t)
|
||||||
|
|
||||||
err := db.AutoMigrate(&types.Account{}, &route.Route{})
|
err := db.AutoMigrate(&types.Account{}, &route.Route{})
|
||||||
@@ -97,6 +137,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||||
|
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||||
db := setupDatabase(t)
|
db := setupDatabase(t)
|
||||||
|
|
||||||
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
||||||
@@ -117,12 +158,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
|||||||
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Save(&account{
|
a := &account{
|
||||||
Account: types.Account{Id: "123"},
|
Account: types.Account{Id: "123"},
|
||||||
Peers: []peer{
|
}
|
||||||
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
|
||||||
}},
|
err = db.Save(a).Error
|
||||||
).Error
|
require.NoError(t, err, "Failed to insert account")
|
||||||
|
|
||||||
|
a.Peers = []peer{
|
||||||
|
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Save(a).Error
|
||||||
require.NoError(t, err, "Failed to insert blob data")
|
require.NoError(t, err, "Failed to insert blob data")
|
||||||
|
|
||||||
var blobValue string
|
var blobValue string
|
||||||
@@ -143,12 +190,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
|||||||
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
||||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
err = db.Save(&types.Account{
|
account := &types.Account{
|
||||||
Id: "1234",
|
Id: "1234",
|
||||||
PeersG: []nbpeer.Peer{
|
}
|
||||||
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
|
||||||
}},
|
err = db.Save(account).Error
|
||||||
).Error
|
require.NoError(t, err, "Failed to insert account")
|
||||||
|
|
||||||
|
account.PeersG = []nbpeer.Peer{
|
||||||
|
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Save(account).Error
|
||||||
require.NoError(t, err, "Failed to insert JSON data")
|
require.NoError(t, err, "Failed to insert JSON data")
|
||||||
|
|
||||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
|
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
|
||||||
@@ -162,12 +215,13 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
|||||||
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
|
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
|
||||||
db := setupDatabase(t)
|
db := setupDatabase(t)
|
||||||
|
|
||||||
err := db.AutoMigrate(&types.SetupKey{})
|
err := db.AutoMigrate(&types.SetupKey{}, &nbpeer.Peer{})
|
||||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
err = db.Save(&types.SetupKey{
|
err = db.Save(&types.SetupKey{
|
||||||
Id: "1",
|
Id: "1",
|
||||||
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
|
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
require.NoError(t, err, "Failed to insert setup key")
|
require.NoError(t, err, "Failed to insert setup key")
|
||||||
|
|
||||||
@@ -192,6 +246,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
|
|||||||
Id: "1",
|
Id: "1",
|
||||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||||
KeySecret: "EEFDA****",
|
KeySecret: "EEFDA****",
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
require.NoError(t, err, "Failed to insert setup key")
|
require.NoError(t, err, "Failed to insert setup key")
|
||||||
|
|
||||||
@@ -213,8 +268,9 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
|
|||||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
err = db.Save(&types.SetupKey{
|
err = db.Save(&types.SetupKey{
|
||||||
Id: "1",
|
Id: "1",
|
||||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
require.NoError(t, err, "Failed to insert setup key")
|
require.NoError(t, err, "Failed to insert setup key")
|
||||||
|
|
||||||
@@ -235,8 +291,9 @@ func TestDropIndex(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
err = db.Save(&types.SetupKey{
|
err = db.Save(&types.SetupKey{
|
||||||
Id: "1",
|
Id: "1",
|
||||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
require.NoError(t, err, "Failed to insert setup key")
|
require.NoError(t, err, "Failed to insert setup key")
|
||||||
|
|
||||||
@@ -249,3 +306,37 @@ func TestDropIndex(t *testing.T) {
|
|||||||
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
|
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
|
||||||
assert.False(t, exist, "Should not have the index")
|
assert.False(t, exist, "Should not have the index")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateIndex(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
err := db.AutoMigrate(&nbpeer.Peer{})
|
||||||
|
assert.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
|
indexName := "idx_account_ip"
|
||||||
|
|
||||||
|
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
|
||||||
|
assert.NoError(t, err, "Migration should not fail to create index")
|
||||||
|
|
||||||
|
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
|
||||||
|
assert.True(t, exist, "Should have the index")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateIndexIfExists(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
err := db.AutoMigrate(&nbpeer.Peer{})
|
||||||
|
assert.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
|
indexName := "idx_account_ip"
|
||||||
|
|
||||||
|
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
|
||||||
|
assert.NoError(t, err, "Migration should not fail to create index")
|
||||||
|
|
||||||
|
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
|
||||||
|
assert.True(t, exist, "Should have the index")
|
||||||
|
|
||||||
|
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
|
||||||
|
assert.NoError(t, err, "Create index should not fail if index exists")
|
||||||
|
|
||||||
|
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
|
||||||
|
assert.True(t, exist, "Should have the index")
|
||||||
|
}
|
||||||
|
|||||||
@@ -120,14 +120,20 @@ type MockAccountManager struct {
|
|||||||
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
||||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||||
|
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||||
|
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||||
// do nothing
|
if am.UpdateAccountPeersFunc != nil {
|
||||||
|
am.UpdateAccountPeersFunc(ctx, accountID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||||
// do nothing
|
if am.BufferUpdateAccountPeersFunc != nil {
|
||||||
|
am.BufferUpdateAccountPeersFunc(ctx, accountID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||||
|
|||||||
@@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNSStore(t *testing.T) (store.Store, error) {
|
func createNSStore(t *testing.T) (store.Store, error) {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
@@ -236,11 +237,23 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
|
|
||||||
if peer.Name != update.Name {
|
if peer.Name != update.Name {
|
||||||
var newLabel string
|
var newLabel string
|
||||||
newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name)
|
|
||||||
|
newLabel, err = nbdns.GetParsedDomainLabel(update.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
newLabel = ""
|
||||||
|
} else {
|
||||||
|
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name)
|
||||||
|
if err == nil {
|
||||||
|
newLabel = ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if newLabel == "" {
|
||||||
|
newLabel, err = getPeerIPDNSLabel(peer.IP, update.Name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
peer.Name = update.Name
|
peer.Name = update.Name
|
||||||
peer.DNSLabel = newLabel
|
peer.DNSLabel = newLabel
|
||||||
peerLabelChanged = true
|
peerLabelChanged = true
|
||||||
@@ -472,6 +485,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
var groupsToAdd []string
|
var groupsToAdd []string
|
||||||
var allowExtraDNSLabels bool
|
var allowExtraDNSLabels bool
|
||||||
var accountID string
|
var accountID string
|
||||||
|
var isEphemeral bool
|
||||||
if addedByUser {
|
if addedByUser {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -501,7 +515,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
setupKeyName = sk.Name
|
setupKeyName = sk.Name
|
||||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||||
accountID = sk.AccountID
|
accountID = sk.AccountID
|
||||||
|
isEphemeral = sk.Ephemeral
|
||||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||||
}
|
}
|
||||||
@@ -573,11 +587,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
var freeLabel string
|
var freeLabel string
|
||||||
freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname)
|
if isEphemeral || attempt > 1 {
|
||||||
if err != nil {
|
freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname)
|
||||||
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newPeer.DNSLabel = freeLabel
|
newPeer.DNSLabel = freeLabel
|
||||||
newPeer.IP = freeIP
|
newPeer.IP = freeIP
|
||||||
|
|
||||||
@@ -647,7 +667,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
if isUniqueConstraintError(err) {
|
if isUniqueConstraintError(err) {
|
||||||
unlock()
|
unlock()
|
||||||
unlock = nil
|
unlock = nil
|
||||||
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -681,7 +701,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) {
|
func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
|
||||||
ip = ip.To4()
|
ip = ip.To4()
|
||||||
|
|
||||||
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
|
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
|
||||||
@@ -689,12 +709,6 @@ func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID
|
|||||||
return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
|
return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName)
|
|
||||||
if err != nil {
|
|
||||||
//nolint:nilerr
|
|
||||||
return dnsName, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
|
return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1267,18 +1281,39 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
type bufferUpdate struct {
|
||||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
mu sync.Mutex
|
||||||
lock := mu.(*sync.Mutex)
|
next *time.Timer
|
||||||
|
update atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
if !lock.TryLock() {
|
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||||
|
bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||||
|
b := bufUpd.(*bufferUpdate)
|
||||||
|
|
||||||
|
if !b.mu.TryLock() {
|
||||||
|
b.update.Store(true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if b.next != nil {
|
||||||
|
b.next.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
defer b.mu.Unlock()
|
||||||
lock.Unlock()
|
|
||||||
am.UpdateAccountPeers(ctx, accountID)
|
am.UpdateAccountPeers(ctx, accountID)
|
||||||
|
if !b.update.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.update.Store(false)
|
||||||
|
if b.next == nil {
|
||||||
|
b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
|
am.UpdateAccountPeers(ctx, accountID)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -25,6 +26,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@@ -1271,7 +1273,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
|||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
|
|
||||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
@@ -1351,7 +1353,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
AnyTimes()
|
AnyTimes()
|
||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
|
|
||||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
@@ -1494,7 +1496,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
|||||||
|
|
||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
|
|
||||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
@@ -1568,7 +1570,7 @@ func Test_LoginPeer(t *testing.T) {
|
|||||||
AnyTimes()
|
AnyTimes()
|
||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
|
|
||||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
@@ -1846,7 +1848,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
return update, true, nil
|
return update, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
|
manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
@@ -1868,7 +1870,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
return update, false, nil
|
return update, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
|
manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
@@ -2251,3 +2253,131 @@ func Test_AddPeer(t *testing.T) {
|
|||||||
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
|
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
|
||||||
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
|
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBufferUpdateAccountPeers(t *testing.T) {
|
||||||
|
const (
|
||||||
|
peersCount = 1000
|
||||||
|
updateAccountInterval = 50 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
|
||||||
|
uapLastRun, dpLastRun atomic.Int64
|
||||||
|
|
||||||
|
totalNewRuns, totalOldRuns int
|
||||||
|
)
|
||||||
|
|
||||||
|
uap := func(ctx context.Context, accountID string) {
|
||||||
|
updatePeersDeleted.Store(deletedPeers.Load())
|
||||||
|
updatePeersRuns.Add(1)
|
||||||
|
uapLastRun.Store(time.Now().UnixMilli())
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("new approach", func(t *testing.T) {
|
||||||
|
updatePeersRuns.Store(0)
|
||||||
|
updatePeersDeleted.Store(0)
|
||||||
|
deletedPeers.Store(0)
|
||||||
|
|
||||||
|
var mustore sync.Map
|
||||||
|
bufupd := func(ctx context.Context, accountID string) {
|
||||||
|
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
|
||||||
|
b := mu.(*bufferUpdate)
|
||||||
|
|
||||||
|
if !b.mu.TryLock() {
|
||||||
|
b.update.Store(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.next != nil {
|
||||||
|
b.next.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
uap(ctx, accountID)
|
||||||
|
if !b.update.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.update.Store(false)
|
||||||
|
b.next = time.AfterFunc(updateAccountInterval, func() {
|
||||||
|
uap(ctx, accountID)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||||
|
deletedPeers.Add(1)
|
||||||
|
dpLastRun.Store(time.Now().UnixMilli())
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
bufupd(ctx, accountID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
am := mock_server.MockAccountManager{
|
||||||
|
UpdateAccountPeersFunc: uap,
|
||||||
|
BufferUpdateAccountPeersFunc: bufupd,
|
||||||
|
DeletePeerFunc: dp,
|
||||||
|
}
|
||||||
|
empty := ""
|
||||||
|
for range peersCount {
|
||||||
|
//nolint
|
||||||
|
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||||
|
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||||
|
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||||
|
|
||||||
|
totalNewRuns = int(updatePeersRuns.Load())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("old approach", func(t *testing.T) {
|
||||||
|
updatePeersRuns.Store(0)
|
||||||
|
updatePeersDeleted.Store(0)
|
||||||
|
deletedPeers.Store(0)
|
||||||
|
|
||||||
|
var mustore sync.Map
|
||||||
|
bufupd := func(ctx context.Context, accountID string) {
|
||||||
|
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
|
||||||
|
b := mu.(*sync.Mutex)
|
||||||
|
|
||||||
|
if !b.TryLock() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(updateAccountInterval)
|
||||||
|
b.Unlock()
|
||||||
|
uap(ctx, accountID)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||||
|
deletedPeers.Add(1)
|
||||||
|
dpLastRun.Store(time.Now().UnixMilli())
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
bufupd(ctx, accountID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
am := mock_server.MockAccountManager{
|
||||||
|
UpdateAccountPeersFunc: uap,
|
||||||
|
BufferUpdateAccountPeersFunc: bufupd,
|
||||||
|
DeletePeerFunc: dp,
|
||||||
|
}
|
||||||
|
empty := ""
|
||||||
|
for range peersCount {
|
||||||
|
//nolint
|
||||||
|
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||||
|
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||||
|
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||||
|
|
||||||
|
totalOldRuns = int(updatePeersRuns.Load())
|
||||||
|
})
|
||||||
|
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||||
|
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
|
||||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||||
|
|||||||
@@ -852,7 +852,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
|||||||
am := DefaultAccountManager{
|
am := DefaultAccountManager{
|
||||||
Store: store,
|
Store: store,
|
||||||
eventStore: &activity.InMemoryEventStore{},
|
eventStore: &activity.InMemoryEventStore{},
|
||||||
integratedPeerValidator: MocIntegratedValidator{},
|
integratedPeerValidator: MockIntegratedValidator{},
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ var (
|
|||||||
baseWallNano int64
|
baseWallNano int64
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Time int64
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
baseWallTime = time.Now()
|
baseWallTime = time.Now()
|
||||||
baseWallNano = baseWallTime.UnixNano()
|
baseWallNano = baseWallTime.UnixNano()
|
||||||
@@ -23,7 +25,11 @@ func init() {
|
|||||||
// and using time.Since() for elapsed calculation, this avoids repeated
|
// and using time.Since() for elapsed calculation, this avoids repeated
|
||||||
// time.Now() calls and leverages Go's internal monotonic clock for
|
// time.Now() calls and leverages Go's internal monotonic clock for
|
||||||
// efficient duration measurement.
|
// efficient duration measurement.
|
||||||
func Now() int64 {
|
func Now() Time {
|
||||||
elapsed := time.Since(baseWallTime)
|
elapsed := time.Since(baseWallTime)
|
||||||
return baseWallNano + int64(elapsed)
|
return Time(baseWallNano + int64(elapsed))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Since(t Time) time.Duration {
|
||||||
|
return time.Duration(Now() - t)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,6 @@ import (
|
|||||||
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Validator is an interface that defines the Validate method.
|
|
||||||
type Validator interface {
|
|
||||||
Validate(any) error
|
|
||||||
// Deprecated: Use Validate instead.
|
|
||||||
ValidateHelloMsgType(any) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type TimedHMACValidator struct {
|
type TimedHMACValidator struct {
|
||||||
authenticatorV2 *authv2.Validator
|
authenticatorV2 *authv2.Validator
|
||||||
authenticator *auth.TimedHMACValidator
|
authenticator *auth.TimedHMACValidator
|
||||||
|
|||||||
@@ -124,15 +124,14 @@ func (cc *connContainer) close() {
|
|||||||
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
parentCtx context.Context
|
|
||||||
connectionURL string
|
connectionURL string
|
||||||
authTokenStore *auth.TokenStore
|
authTokenStore *auth.TokenStore
|
||||||
hashedID []byte
|
hashedID messages.PeerID
|
||||||
|
|
||||||
bufPool *sync.Pool
|
bufPool *sync.Pool
|
||||||
|
|
||||||
relayConn net.Conn
|
relayConn net.Conn
|
||||||
conns map[string]*connContainer
|
conns map[messages.PeerID]*connContainer
|
||||||
serviceIsRunning bool
|
serviceIsRunning bool
|
||||||
mu sync.Mutex // protect serviceIsRunning and conns
|
mu sync.Mutex // protect serviceIsRunning and conns
|
||||||
readLoopMutex sync.Mutex
|
readLoopMutex sync.Mutex
|
||||||
@@ -142,14 +141,17 @@ type Client struct {
|
|||||||
|
|
||||||
onDisconnectListener func(string)
|
onDisconnectListener func(string)
|
||||||
listenerMutex sync.Mutex
|
listenerMutex sync.Mutex
|
||||||
|
|
||||||
|
stateSubscription *PeersStateSubscription
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||||
hashedID, hashedStringId := messages.HashID(peerID)
|
hashedID := messages.HashID(peerID)
|
||||||
|
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
log: log.WithFields(log.Fields{"relay": serverURL}),
|
log: relayLog,
|
||||||
parentCtx: ctx,
|
|
||||||
connectionURL: serverURL,
|
connectionURL: serverURL,
|
||||||
authTokenStore: authTokenStore,
|
authTokenStore: authTokenStore,
|
||||||
hashedID: hashedID,
|
hashedID: hashedID,
|
||||||
@@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
|
|||||||
return &buf
|
return &buf
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
conns: make(map[string]*connContainer),
|
conns: make(map[messages.PeerID]*connContainer),
|
||||||
}
|
}
|
||||||
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
|
|
||||||
|
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect(ctx context.Context) error {
|
||||||
c.log.Infof("connecting to relay server")
|
c.log.Infof("connecting to relay server")
|
||||||
c.readLoopMutex.Lock()
|
c.readLoopMutex.Lock()
|
||||||
defer c.readLoopMutex.Unlock()
|
defer c.readLoopMutex.Unlock()
|
||||||
@@ -178,17 +181,27 @@ func (c *Client) Connect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.connect(); err != nil {
|
instanceURL, err := c.connect(ctx)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
c.muInstanceURL.Lock()
|
||||||
|
c.instanceURL = instanceURL
|
||||||
|
c.muInstanceURL.Unlock()
|
||||||
|
|
||||||
c.log = c.log.WithField("relay", c.instanceURL.String())
|
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
|
||||||
|
|
||||||
|
c.log = c.log.WithField("relay", instanceURL.String())
|
||||||
c.log.Infof("relay connection established")
|
c.log.Infof("relay connection established")
|
||||||
|
|
||||||
c.serviceIsRunning = true
|
c.serviceIsRunning = true
|
||||||
|
|
||||||
|
internallyStoppedFlag := newInternalStopFlag()
|
||||||
|
hc := healthcheck.NewReceiver(c.log)
|
||||||
|
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
|
||||||
|
|
||||||
c.wgReadLoop.Add(1)
|
c.wgReadLoop.Add(1)
|
||||||
go c.readLoop(c.relayConn)
|
go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -196,26 +209,50 @@ func (c *Client) Connect() error {
|
|||||||
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
||||||
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
||||||
// it will return immediately.
|
// it will return immediately.
|
||||||
|
// It block until the server confirm the peer is online.
|
||||||
// todo: what should happen if call with the same peerID with multiple times?
|
// todo: what should happen if call with the same peerID with multiple times?
|
||||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
|
||||||
c.mu.Lock()
|
peerID := messages.HashID(dstPeerID)
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
if !c.serviceIsRunning {
|
if !c.serviceIsRunning {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, fmt.Errorf("relay connection is not established")
|
||||||
|
}
|
||||||
|
_, ok := c.conns[peerID]
|
||||||
|
if ok {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, ErrConnAlreadyExists
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
|
||||||
|
c.log.Errorf("peer not available: %s, %s", peerID, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
|
||||||
|
msgChannel := make(chan Msg, 100)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if !c.serviceIsRunning {
|
||||||
|
c.mu.Unlock()
|
||||||
return nil, fmt.Errorf("relay connection is not established")
|
return nil, fmt.Errorf("relay connection is not established")
|
||||||
}
|
}
|
||||||
|
|
||||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
c.muInstanceURL.Lock()
|
||||||
_, ok := c.conns[hashedStringID]
|
instanceURL := c.instanceURL
|
||||||
|
c.muInstanceURL.Unlock()
|
||||||
|
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
||||||
|
|
||||||
|
_, ok = c.conns[peerID]
|
||||||
if ok {
|
if ok {
|
||||||
|
c.mu.Unlock()
|
||||||
|
_ = conn.Close()
|
||||||
return nil, ErrConnAlreadyExists
|
return nil, ErrConnAlreadyExists
|
||||||
}
|
}
|
||||||
|
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||||
c.log.Infof("open connection to peer: %s", hashedStringID)
|
c.mu.Unlock()
|
||||||
msgChannel := make(chan Msg, 100)
|
|
||||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
|
||||||
|
|
||||||
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,76 +291,70 @@ func (c *Client) Close() error {
|
|||||||
return c.close(true)
|
return c.close(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) connect() error {
|
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.relayConn = conn
|
c.relayConn = conn
|
||||||
|
|
||||||
if err = c.handShake(); err != nil {
|
instanceURL, err := c.handShake(ctx)
|
||||||
|
if err != nil {
|
||||||
cErr := conn.Close()
|
cErr := conn.Close()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
c.log.Errorf("failed to close connection: %s", cErr)
|
c.log.Errorf("failed to close connection: %s", cErr)
|
||||||
}
|
}
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return instanceURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handShake() error {
|
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to marshal auth message: %s", err)
|
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.relayConn.Write(msg)
|
_, err = c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to send auth message: %s", err)
|
c.log.Errorf("failed to send auth message: %s", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
buf := make([]byte, messages.MaxHandshakeRespSize)
|
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||||
n, err := c.readWithTimeout(buf)
|
n, err := c.readWithTimeout(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to read auth response: %s", err)
|
c.log.Errorf("failed to read auth response: %s", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = messages.ValidateVersion(buf[:n])
|
_, err = messages.ValidateVersion(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("validate version: %w", err)
|
return nil, fmt.Errorf("validate version: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, err := messages.DetermineServerMessageType(buf[:n])
|
msgType, err := messages.DetermineServerMessageType(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to determine message type: %s", err)
|
c.log.Errorf("failed to determine message type: %s", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if msgType != messages.MsgTypeAuthResponse {
|
if msgType != messages.MsgTypeAuthResponse {
|
||||||
c.log.Errorf("unexpected message type: %s", msgType)
|
c.log.Errorf("unexpected message type: %s", msgType)
|
||||||
return fmt.Errorf("unexpected message type")
|
return nil, fmt.Errorf("unexpected message type")
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := messages.UnmarshalAuthResponse(buf[:n])
|
addr, err := messages.UnmarshalAuthResponse(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.muInstanceURL.Lock()
|
return &RelayAddr{addr: addr}, nil
|
||||||
c.instanceURL = &RelayAddr{addr: addr}
|
|
||||||
c.muInstanceURL.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readLoop(relayConn net.Conn) {
|
func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
|
||||||
internallyStoppedFlag := newInternalStopFlag()
|
|
||||||
hc := healthcheck.NewReceiver(c.log)
|
|
||||||
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errExit error
|
errExit error
|
||||||
n int
|
n int
|
||||||
@@ -366,10 +397,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
|
|
||||||
hc.Stop()
|
hc.Stop()
|
||||||
|
|
||||||
c.muInstanceURL.Lock()
|
c.stateSubscription.Cleanup()
|
||||||
c.instanceURL = nil
|
|
||||||
c.muInstanceURL.Unlock()
|
|
||||||
|
|
||||||
c.wgReadLoop.Done()
|
c.wgReadLoop.Done()
|
||||||
_ = c.close(false)
|
_ = c.close(false)
|
||||||
c.notifyDisconnected()
|
c.notifyDisconnected()
|
||||||
@@ -382,6 +410,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
|
|||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
case messages.MsgTypeTransport:
|
case messages.MsgTypeTransport:
|
||||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||||
|
case messages.MsgTypePeersOnline:
|
||||||
|
c.handlePeersOnlineMsg(buf)
|
||||||
|
c.bufPool.Put(bufPtr)
|
||||||
|
return true
|
||||||
|
case messages.MsgTypePeersWentOffline:
|
||||||
|
c.handlePeersWentOfflineMsg(buf)
|
||||||
|
c.bufPool.Put(bufPtr)
|
||||||
|
return true
|
||||||
case messages.MsgTypeClose:
|
case messages.MsgTypeClose:
|
||||||
c.log.Debugf("relay connection close by server")
|
c.log.Debugf("relay connection close by server")
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
@@ -413,18 +449,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
stringID := messages.HashIDToString(peerID)
|
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if !c.serviceIsRunning {
|
if !c.serviceIsRunning {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
container, ok := c.conns[stringID]
|
container, ok := c.conns[*peerID]
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
c.log.Errorf("peer not found: %s", stringID)
|
c.log.Errorf("peer not found: %s", peerID.String())
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -437,9 +471,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
|
func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
conn, ok := c.conns[id]
|
conn, ok := c.conns[dstID]
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, net.ErrClosed
|
return 0, net.ErrClosed
|
||||||
@@ -464,7 +498,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
|||||||
return len(payload), err
|
return len(payload), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case _, ok := <-hc.OnTimeout:
|
case _, ok := <-hc.OnTimeout:
|
||||||
@@ -478,7 +512,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
|
|||||||
c.log.Warnf("failed to close connection: %s", err)
|
c.log.Warnf("failed to close connection: %s", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-c.parentCtx.Done():
|
case <-ctx.Done():
|
||||||
err := c.close(true)
|
err := c.close(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to teardown connection: %s", err)
|
c.log.Errorf("failed to teardown connection: %s", err)
|
||||||
@@ -492,10 +526,31 @@ func (c *Client) closeAllConns() {
|
|||||||
for _, container := range c.conns {
|
for _, container := range c.conns {
|
||||||
container.close()
|
container.close()
|
||||||
}
|
}
|
||||||
c.conns = make(map[string]*connContainer)
|
c.conns = make(map[messages.PeerID]*connContainer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) closeConn(connReference *Conn, id string) error {
|
func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
container, ok := c.conns[peerID]
|
||||||
|
if !ok {
|
||||||
|
c.log.Warnf("can not close connection, peer not found: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
|
||||||
|
container.close()
|
||||||
|
delete(c.conns, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
|
||||||
|
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
@@ -507,6 +562,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
|||||||
if container.conn != connReference {
|
if container.conn != connReference {
|
||||||
return fmt.Errorf("conn reference mismatch")
|
return fmt.Errorf("conn reference mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
|
||||||
|
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
c.log.Infof("free up connection to peer: %s", id)
|
c.log.Infof("free up connection to peer: %s", id)
|
||||||
delete(c.conns, id)
|
delete(c.conns, id)
|
||||||
container.close()
|
container.close()
|
||||||
@@ -525,8 +585,12 @@ func (c *Client) close(gracefullyExit bool) error {
|
|||||||
c.log.Warn("relay connection was already marked as not running")
|
c.log.Warn("relay connection was already marked as not running")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.serviceIsRunning = false
|
c.serviceIsRunning = false
|
||||||
|
|
||||||
|
c.muInstanceURL.Lock()
|
||||||
|
c.instanceURL = nil
|
||||||
|
c.muInstanceURL.Unlock()
|
||||||
|
|
||||||
c.log.Infof("closing all peer connections")
|
c.log.Infof("closing all peer connections")
|
||||||
c.closeAllConns()
|
c.closeAllConns()
|
||||||
if gracefullyExit {
|
if gracefullyExit {
|
||||||
@@ -559,8 +623,8 @@ func (c *Client) writeCloseMsg() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
|
||||||
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
|
ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
readDone := make(chan struct{})
|
readDone := make(chan struct{})
|
||||||
@@ -581,3 +645,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) handlePeersOnlineMsg(buf []byte) {
|
||||||
|
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.stateSubscription.OnPeersOnline(peersID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
|
||||||
|
peersID, err := messages.UnMarshalPeersWentOffline(buf)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.stateSubscription.OnPeersWentOffline(peersID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,14 +18,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
av = &allow.Auth{}
|
|
||||||
hmacTokenStore = &hmac.TokenStore{}
|
hmacTokenStore = &hmac.TokenStore{}
|
||||||
serverListenAddr = "127.0.0.1:1234"
|
serverListenAddr = "127.0.0.1:1234"
|
||||||
serverURL = "rel://127.0.0.1:1234"
|
serverURL = "rel://127.0.0.1:1234"
|
||||||
|
serverCfg = server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: serverURL,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
_ = util.InitLog("error", "console")
|
_ = util.InitLog("debug", "console")
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
@@ -33,7 +38,7 @@ func TestMain(m *testing.M) {
|
|||||||
func TestClient(t *testing.T) {
|
func TestClient(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -58,37 +63,37 @@ func TestClient(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
t.Log("alice connecting to server")
|
t.Log("alice connecting to server")
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientAlice.Close()
|
defer clientAlice.Close()
|
||||||
|
|
||||||
t.Log("placeholder connecting to server")
|
t.Log("placeholder connecting to server")
|
||||||
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
|
clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
|
||||||
err = clientPlaceHolder.Connect()
|
err = clientPlaceHolder.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientPlaceHolder.Close()
|
defer clientPlaceHolder.Close()
|
||||||
|
|
||||||
t.Log("Bob connecting to server")
|
t.Log("Bob connecting to server")
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientBob.Close()
|
defer clientBob.Close()
|
||||||
|
|
||||||
t.Log("Alice open connection to Bob")
|
t.Log("Alice open connection to Bob")
|
||||||
connAliceToBob, err := clientAlice.OpenConn("bob")
|
connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("Bob open connection to Alice")
|
t.Log("Bob open connection to Alice")
|
||||||
connBobToAlice, err := clientBob.OpenConn("alice")
|
connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -115,7 +120,7 @@ func TestClient(t *testing.T) {
|
|||||||
func TestRegistration(t *testing.T) {
|
func TestRegistration(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = srv.Shutdown(ctx)
|
_ = srv.Shutdown(ctx)
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) {
|
|||||||
_ = fakeTCPListener.Close()
|
_ = fakeTCPListener.Close()
|
||||||
}(fakeTCPListener)
|
}(fakeTCPListener)
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
|
clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -189,7 +194,7 @@ func TestEcho(t *testing.T) {
|
|||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -213,8 +218,8 @@ func TestEcho(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -225,8 +230,8 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -237,12 +242,12 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
_, err = clientAlice.OpenConn("bob")
|
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err == nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("expected error when binding to unavailable peer, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("closing client")
|
log.Infof("closing client")
|
||||||
@@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = clientAlice.OpenConn("bob")
|
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
chBob, err := clientBob.OpenConn(ctx, "alice")
|
||||||
err = clientBob.Connect()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chBob, err := clientBob.OpenConn("alice")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
chAlice, err := clientAlice.OpenConn("bob")
|
chAlice, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
testString := "hello alice, I am bob"
|
testString := "hello alice, I am bob"
|
||||||
|
_, err = chBob.Write([]byte(testString))
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error when writing to channel, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
chBob, err = clientBob.OpenConn(ctx, "alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = chBob.Write([]byte(testString))
|
_, err = chBob.Write([]byte(testString))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to write to channel: %s", err)
|
t.Errorf("failed to write to channel: %s", err)
|
||||||
@@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
bob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
err = clientAlice.Connect()
|
err = bob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := clientAlice.OpenConn("bob")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
|
err = clientAlice.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
bob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
err = clientAlice.Connect()
|
err = bob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := clientAlice.OpenConn("bob")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
|
err = clientAlice.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
t.Errorf("unexpected reading from closed connection")
|
t.Errorf("unexpected reading from closed connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = clientAlice.OpenConn("bob")
|
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("unexpected opening connection to closed server")
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
}
|
}
|
||||||
@@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv1, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -544,8 +571,8 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = relayClient.Connect()
|
err = relayClient.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -567,7 +594,7 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
log.Fatalf("timeout waiting for client to disconnect")
|
log.Fatalf("timeout waiting for client to disconnect")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = relayClient.OpenConn("bob")
|
_, err = relayClient.OpenConn(ctx, "bob")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("unexpected opening connection to closed server")
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
}
|
}
|
||||||
@@ -577,7 +604,7 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -596,8 +623,8 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = relayClient.Connect()
|
err = relayClient.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -607,7 +634,7 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = relayClient.OpenConn("bob")
|
_, err = relayClient.OpenConn(ctx, "bob")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("unexpected opening connection to closed server")
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
}
|
}
|
||||||
@@ -623,7 +650,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -647,8 +674,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -659,8 +686,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -671,12 +698,12 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,13 +3,14 @@ package client
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conn represent a connection to a relayed remote peer.
|
// Conn represent a connection to a relayed remote peer.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
client *Client
|
client *Client
|
||||||
dstID []byte
|
dstID messages.PeerID
|
||||||
dstStringID string
|
|
||||||
messageChan chan Msg
|
messageChan chan Msg
|
||||||
instanceURL *RelayAddr
|
instanceURL *RelayAddr
|
||||||
}
|
}
|
||||||
@@ -17,14 +18,12 @@ type Conn struct {
|
|||||||
// NewConn creates a new connection to a relayed remote peer.
|
// NewConn creates a new connection to a relayed remote peer.
|
||||||
// client: the client instance, it used to send messages to the destination peer
|
// client: the client instance, it used to send messages to the destination peer
|
||||||
// dstID: the destination peer ID
|
// dstID: the destination peer ID
|
||||||
// dstStringID: the destination peer ID in string format
|
|
||||||
// messageChan: the channel where the messages will be received
|
// messageChan: the channel where the messages will be received
|
||||||
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
|
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
|
||||||
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
|
func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
client: client,
|
client: client,
|
||||||
dstID: dstID,
|
dstID: dstID,
|
||||||
dstStringID: dstStringID,
|
|
||||||
messageChan: messageChan,
|
messageChan: messageChan,
|
||||||
instanceURL: instanceURL,
|
instanceURL: instanceURL,
|
||||||
}
|
}
|
||||||
@@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||||
return c.client.writeTo(c, c.dstStringID, c.dstID, p)
|
return c.client.writeTo(c, c.dstID, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
@@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return c.client.closeConn(c, c.dstStringID)
|
return c.client.closeConn(c, c.dstID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
|
|||||||
|
|
||||||
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
||||||
|
|
||||||
if err := rc.Connect(); err != nil {
|
if err := rc.Connect(parentCtx); err != nil {
|
||||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ type OnServerCloseListener func()
|
|||||||
// ManagerService is the interface for the relay manager.
|
// ManagerService is the interface for the relay manager.
|
||||||
type ManagerService interface {
|
type ManagerService interface {
|
||||||
Serve() error
|
Serve() error
|
||||||
OpenConn(serverAddress, peerKey string) (net.Conn, error)
|
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
|
||||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||||
RelayInstanceAddress() (string, error)
|
RelayInstanceAddress() (string, error)
|
||||||
ServerURLs() []string
|
ServerURLs() []string
|
||||||
@@ -65,7 +65,7 @@ type Manager struct {
|
|||||||
|
|
||||||
relayClient *Client
|
relayClient *Client
|
||||||
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
|
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
|
||||||
relayClientMu sync.Mutex
|
relayClientMu sync.RWMutex
|
||||||
reconnectGuard *Guard
|
reconnectGuard *Guard
|
||||||
|
|
||||||
relayClients map[string]*RelayTrack
|
relayClients map[string]*RelayTrack
|
||||||
@@ -123,9 +123,9 @@ func (m *Manager) Serve() error {
|
|||||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||||
m.relayClientMu.Lock()
|
m.relayClientMu.RLock()
|
||||||
defer m.relayClientMu.Unlock()
|
defer m.relayClientMu.RUnlock()
|
||||||
|
|
||||||
if m.relayClient == nil {
|
if m.relayClient == nil {
|
||||||
return nil, ErrRelayClientNotConnected
|
return nil, ErrRelayClientNotConnected
|
||||||
@@ -141,10 +141,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
)
|
)
|
||||||
if !foreign {
|
if !foreign {
|
||||||
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
||||||
netConn, err = m.relayClient.OpenConn(peerKey)
|
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||||
netConn, err = m.openConnVia(serverAddress, peerKey)
|
netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -155,8 +155,8 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
|
|
||||||
// Ready returns true if the home Relay client is connected to the relay server.
|
// Ready returns true if the home Relay client is connected to the relay server.
|
||||||
func (m *Manager) Ready() bool {
|
func (m *Manager) Ready() bool {
|
||||||
m.relayClientMu.Lock()
|
m.relayClientMu.RLock()
|
||||||
defer m.relayClientMu.Unlock()
|
defer m.relayClientMu.RUnlock()
|
||||||
|
|
||||||
if m.relayClient == nil {
|
if m.relayClient == nil {
|
||||||
return false
|
return false
|
||||||
@@ -174,8 +174,8 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
|
|||||||
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
||||||
// closed.
|
// closed.
|
||||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||||
m.relayClientMu.Lock()
|
m.relayClientMu.RLock()
|
||||||
defer m.relayClientMu.Unlock()
|
defer m.relayClientMu.RUnlock()
|
||||||
|
|
||||||
if m.relayClient == nil {
|
if m.relayClient == nil {
|
||||||
return ErrRelayClientNotConnected
|
return ErrRelayClientNotConnected
|
||||||
@@ -199,8 +199,8 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
|
|||||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||||
m.relayClientMu.Lock()
|
m.relayClientMu.RLock()
|
||||||
defer m.relayClientMu.Unlock()
|
defer m.relayClientMu.RUnlock()
|
||||||
|
|
||||||
if m.relayClient == nil {
|
if m.relayClient == nil {
|
||||||
return "", ErrRelayClientNotConnected
|
return "", ErrRelayClientNotConnected
|
||||||
@@ -229,7 +229,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
|||||||
return m.tokenStore.UpdateToken(token)
|
return m.tokenStore.UpdateToken(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||||
// check if already has a connection to the desired relay server
|
// check if already has a connection to the desired relay server
|
||||||
m.relayClientsMutex.RLock()
|
m.relayClientsMutex.RLock()
|
||||||
rt, ok := m.relayClients[serverAddress]
|
rt, ok := m.relayClients[serverAddress]
|
||||||
@@ -240,7 +240,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
if rt.err != nil {
|
if rt.err != nil {
|
||||||
return nil, rt.err
|
return nil, rt.err
|
||||||
}
|
}
|
||||||
return rt.relayClient.OpenConn(peerKey)
|
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||||
}
|
}
|
||||||
m.relayClientsMutex.RUnlock()
|
m.relayClientsMutex.RUnlock()
|
||||||
|
|
||||||
@@ -255,7 +255,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
if rt.err != nil {
|
if rt.err != nil {
|
||||||
return nil, rt.err
|
return nil, rt.err
|
||||||
}
|
}
|
||||||
return rt.relayClient.OpenConn(peerKey)
|
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a new relay client and store it in the relayClients map
|
// create a new relay client and store it in the relayClients map
|
||||||
@@ -264,8 +264,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
m.relayClients[serverAddress] = rt
|
m.relayClients[serverAddress] = rt
|
||||||
m.relayClientsMutex.Unlock()
|
m.relayClientsMutex.Unlock()
|
||||||
|
|
||||||
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
|
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
|
||||||
err := relayClient.Connect()
|
err := relayClient.Connect(m.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rt.err = err
|
rt.err = err
|
||||||
rt.Unlock()
|
rt.Unlock()
|
||||||
@@ -279,7 +279,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
rt.relayClient = relayClient
|
rt.relayClient = relayClient
|
||||||
rt.Unlock()
|
rt.Unlock()
|
||||||
|
|
||||||
conn, err := relayClient.OpenConn(peerKey)
|
conn, err := relayClient.OpenConn(ctx, peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -300,7 +300,9 @@ func (m *Manager) onServerConnected() {
|
|||||||
func (m *Manager) onServerDisconnected(serverAddress string) {
|
func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||||
m.relayClientMu.Lock()
|
m.relayClientMu.Lock()
|
||||||
if serverAddress == m.relayClient.connectionURL {
|
if serverAddress == m.relayClient.connectionURL {
|
||||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient)
|
go func(client *Client) {
|
||||||
|
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
|
||||||
|
}(m.relayClient)
|
||||||
}
|
}
|
||||||
m.relayClientMu.Unlock()
|
m.relayClientMu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/auth/allow"
|
||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,16 +23,22 @@ func TestEmptyURL(t *testing.T) {
|
|||||||
func TestForeignConn(t *testing.T) {
|
func TestForeignConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg1 := server.ListenerConfig{
|
lstCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
|
||||||
|
srv1, err := server.NewServer(server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: lstCfg1.Address,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := srv1.Listen(srvCfg1)
|
err := srv1.Listen(lstCfg1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}
|
}
|
||||||
@@ -51,7 +58,12 @@ func TestForeignConn(t *testing.T) {
|
|||||||
srvCfg2 := server.ListenerConfig{
|
srvCfg2 := server.ListenerConfig{
|
||||||
Address: "localhost:2234",
|
Address: "localhost:2234",
|
||||||
}
|
}
|
||||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
srv2, err := server.NewServer(server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: srvCfg2.Address,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -74,32 +86,26 @@ func TestForeignConn(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idAlice := "alice"
|
|
||||||
log.Debugf("connect by alice")
|
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
|
||||||
err = clientAlice.Serve()
|
if err := clientAlice.Serve(); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idBob := "bob"
|
clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
|
||||||
log.Debugf("connect by bob")
|
if err := clientBob.Serve(); err != nil {
|
||||||
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
|
|
||||||
err = clientBob.Serve()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get relay address: %s", err)
|
t.Fatalf("failed to get relay address: %s", err)
|
||||||
}
|
}
|
||||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
|
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
|
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -137,7 +143,7 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
srvCfg1 := server.ListenerConfig{
|
srvCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
srv1, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -163,7 +169,7 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
srvCfg2 := server.ListenerConfig{
|
srvCfg2 := server.ListenerConfig{
|
||||||
Address: "localhost:2234",
|
Address: "localhost:2234",
|
||||||
}
|
}
|
||||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
srv2, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -186,16 +192,20 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idAlice := "alice"
|
|
||||||
log.Debugf("connect by alice")
|
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
|
||||||
|
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
|
||||||
|
if err := mgrBob.Serve(); err != nil {
|
||||||
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
|
||||||
err = mgr.Serve()
|
err = mgr.Serve()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -212,23 +222,21 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
srvCfg1 := server.ListenerConfig{
|
srvCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
srv1, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
t.Log("binding server 1.")
|
t.Log("binding server 1.")
|
||||||
err := srv1.Listen(srvCfg1)
|
if err := srv1.Listen(srvCfg1); err != nil {
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
t.Logf("closing server 1.")
|
t.Logf("closing server 1.")
|
||||||
err := srv1.Shutdown(ctx)
|
if err := srv1.Shutdown(ctx); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to close server: %s", err)
|
t.Errorf("failed to close server: %s", err)
|
||||||
}
|
}
|
||||||
t.Logf("server 1. closed")
|
t.Logf("server 1. closed")
|
||||||
@@ -241,7 +249,7 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
srvCfg2 := server.ListenerConfig{
|
srvCfg2 := server.ListenerConfig{
|
||||||
Address: "localhost:2234",
|
Address: "localhost:2234",
|
||||||
}
|
}
|
||||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
srv2, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -277,15 +285,8 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Log("open connection to another peer")
|
t.Log("open connection to another peer")
|
||||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
|
||||||
if err != nil {
|
t.Fatalf("should have failed to open connection to another peer")
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("close conn")
|
|
||||||
err = conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to close connection: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
||||||
@@ -305,7 +306,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
srvCfg := server.ListenerConfig{
|
srvCfg := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@@ -330,6 +331,13 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
|
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
|
||||||
|
err = clientBob.Serve()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
|
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
|
||||||
err = clientAlice.Serve()
|
err = clientAlice.Serve()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -339,7 +347,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get relay address: %s", err)
|
t.Errorf("failed to get relay address: %s", err)
|
||||||
}
|
}
|
||||||
conn, err := clientAlice.OpenConn(ra, "bob")
|
conn, err := clientAlice.OpenConn(ctx, ra, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -357,7 +365,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||||
|
|
||||||
log.Infof("reopent the connection")
|
log.Infof("reopent the connection")
|
||||||
_, err = clientAlice.OpenConn(ra, "bob")
|
_, err = clientAlice.OpenConn(ctx, ra, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to open channel: %s", err)
|
t.Errorf("failed to open channel: %s", err)
|
||||||
}
|
}
|
||||||
@@ -366,24 +374,27 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
func TestNotifierDoubleAdd(t *testing.T) {
|
func TestNotifierDoubleAdd(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg1 := server.ListenerConfig{
|
listenerCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
srv, err := server.NewServer(server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: listenerCfg1.Address,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := srv1.Listen(srvCfg1)
|
if err := srv.Listen(listenerCfg1); err != nil {
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := srv1.Shutdown(ctx)
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to close server: %s", err)
|
t.Errorf("failed to close server: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -392,17 +403,21 @@ func TestNotifierDoubleAdd(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idAlice := "alice"
|
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
|
||||||
err = clientAlice.Serve()
|
clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
|
||||||
if err != nil {
|
if err = clientBob.Serve(); err != nil {
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
|
clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
|
||||||
|
if err = clientAlice.Serve(); err != nil {
|
||||||
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
191
relay/client/peer_subscription.go
Normal file
191
relay/client/peer_subscription.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpenConnectionTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayedConnWriter interface {
|
||||||
|
Write(p []byte) (n int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
|
||||||
|
// over a relay connection. It allows tracking peers' availability and handling offline
|
||||||
|
// events via a callback. We get online notification from the server only once.
|
||||||
|
type PeersStateSubscription struct {
|
||||||
|
log *log.Entry
|
||||||
|
relayConn relayedConnWriter
|
||||||
|
offlineCallback func(peerIDs []messages.PeerID)
|
||||||
|
|
||||||
|
listenForOfflinePeers map[messages.PeerID]struct{}
|
||||||
|
waitingPeers map[messages.PeerID]chan struct{}
|
||||||
|
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
|
||||||
|
return &PeersStateSubscription{
|
||||||
|
log: log,
|
||||||
|
relayConn: relayConn,
|
||||||
|
offlineCallback: offlineCallback,
|
||||||
|
listenForOfflinePeers: make(map[messages.PeerID]struct{}),
|
||||||
|
waitingPeers: make(map[messages.PeerID]chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPeersOnline should be called when a notification is received that certain peers have come online.
|
||||||
|
// It checks if any of the peers are being waited on and signals their availability.
|
||||||
|
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, peerID := range peersID {
|
||||||
|
waitCh, ok := s.waitingPeers[peerID]
|
||||||
|
if !ok {
|
||||||
|
// If meanwhile the peer was unsubscribed, we don't need to signal it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
waitCh <- struct{}{}
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
close(waitCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
|
||||||
|
s.mu.Lock()
|
||||||
|
relevantPeers := make([]messages.PeerID, 0, len(peersID))
|
||||||
|
for _, peerID := range peersID {
|
||||||
|
if _, ok := s.listenForOfflinePeers[peerID]; ok {
|
||||||
|
relevantPeers = append(relevantPeers, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if len(relevantPeers) > 0 {
|
||||||
|
s.offlineCallback(relevantPeers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
|
||||||
|
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
|
||||||
|
// Check if already waiting for this peer
|
||||||
|
s.mu.Lock()
|
||||||
|
if _, exists := s.waitingPeers[peerID]; exists {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return errors.New("already waiting for peer to come online")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a channel to wait for the peer to come online
|
||||||
|
waitCh := make(chan struct{}, 1)
|
||||||
|
s.waitingPeers[peerID] = waitCh
|
||||||
|
s.listenForOfflinePeers[peerID] = struct{}{}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if err := s.subscribeStateChange(peerID); err != nil {
|
||||||
|
s.log.Errorf("failed to subscribe to peer state: %s", err)
|
||||||
|
s.mu.Lock()
|
||||||
|
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
|
||||||
|
close(waitCh)
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
delete(s.listenForOfflinePeers, peerID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for peer to come online or context to be cancelled
|
||||||
|
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
select {
|
||||||
|
case _, ok := <-waitCh:
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("wait for peer to come online has been cancelled")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Debugf("peer %s is now online", peerID)
|
||||||
|
return nil
|
||||||
|
case <-timeoutCtx.Done():
|
||||||
|
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
|
||||||
|
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
|
||||||
|
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
|
||||||
|
close(waitCh)
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
delete(s.listenForOfflinePeers, peerID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return timeoutCtx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
|
||||||
|
msgErr := s.unsubscribeStateChange(peerIDs)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
if wch, ok := s.waitingPeers[peerID]; ok {
|
||||||
|
close(wch)
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(s.listenForOfflinePeers, peerID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return msgErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) Cleanup() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, waitCh := range s.waitingPeers {
|
||||||
|
close(waitCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.waitingPeers = make(map[messages.PeerID]chan struct{})
|
||||||
|
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
|
||||||
|
msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
if _, err := s.relayConn.Write(msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
|
||||||
|
msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var connWriteErr error
|
||||||
|
for _, msg := range msgs {
|
||||||
|
if _, err := s.relayConn.Write(msg); err != nil {
|
||||||
|
connWriteErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connWriteErr
|
||||||
|
}
|
||||||
99
relay/client/peer_subscription_test.go
Normal file
99
relay/client/peer_subscription_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockRelayedConn is a relayedConnWriter stub whose writes always succeed.
type mockRelayedConn struct {
}

// Write reports the full length as written and never fails.
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
	return len(p), nil
}
|
||||||
|
|
||||||
|
// TestWaitToBeOnlineAndSubscribe_Success verifies that WaitToBeOnlineAndSubscribe
// returns without error once OnPeersOnline reports the awaited peer.
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
	peerID := messages.HashID("peer1")
	mockConn := &mockRelayedConn{}
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{}) // discard log output
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Launch wait in background
	go func() {
		// Give the main goroutine time to register the waiter, then signal.
		time.Sleep(100 * time.Millisecond)
		sub.OnPeersOnline([]messages.PeerID{peerID})
	}()

	err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
	assert.NoError(t, err)
}
|
||||||
|
|
||||||
|
// TestWaitToBeOnlineAndSubscribe_Timeout verifies that the wait fails with the
// parent context's deadline error when the peer never comes online.
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
	peerID := messages.HashID("peer2")
	mockConn := &mockRelayedConn{}
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{}) // discard log output
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)

	// The 100ms parent deadline expires well before OpenConnectionTimeout.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
	assert.Error(t, err)
	assert.Equal(t, context.DeadlineExceeded, err)
}
|
||||||
|
|
||||||
|
// TestWaitToBeOnlineAndSubscribe_Duplicate verifies that a second concurrent
// wait for the same peer is rejected with an "already waiting" error.
// NOTE(review): synchronization relies on a 100ms sleep for the first waiter
// to register — potentially flaky on a loaded machine; confirm acceptable.
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
	peerID := messages.HashID("peer3")
	mockConn := &mockRelayedConn{}
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{}) // discard log output
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)

	ctx := context.Background()
	// First waiter registers and blocks; no online event is ever sent.
	go func() {
		_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)

	}()
	time.Sleep(100 * time.Millisecond)
	// A second wait for the same peer must fail immediately.
	err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "already waiting")
}
|
||||||
|
|
||||||
|
// TestUnsubscribeStateChange verifies that unsubscribing a peer releases a
// goroutine currently blocked in WaitToBeOnlineAndSubscribe for that peer.
func TestUnsubscribeStateChange(t *testing.T) {
	peerID := messages.HashID("peer4")
	mockConn := &mockRelayedConn{}
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{}) // discard log output
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)

	doneChan := make(chan struct{})
	go func() {
		_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
		close(doneChan)
	}()
	time.Sleep(100 * time.Millisecond)

	err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
	assert.NoError(t, err)

	select {
	case <-doneChan:
		// The blocked waiter returned, as expected after unsubscribing.
	case <-time.After(200 * time.Millisecond):
		// The waiter is still blocked: the unsubscribe failed to release it.
		t.Errorf("timeout")
	}
}
|
||||||
@@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
|||||||
|
|
||||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||||
log.Infof("try to connecting to relay server: %s", url)
|
log.Infof("try to connecting to relay server: %s", url)
|
||||||
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
|
relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
|
||||||
err := relayClient.Connect()
|
err := relayClient.Connect(ctx)
|
||||||
resultChan <- connResult{
|
resultChan <- connResult{
|
||||||
RelayClient: relayClient,
|
RelayClient: relayClient,
|
||||||
Url: url,
|
Url: url,
|
||||||
|
|||||||
@@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
||||||
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||||
|
|
||||||
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
|
cfg := server.Config{
|
||||||
|
Meter: metricsServer.Meter,
|
||||||
|
ExposedAddress: cobraConfig.ExposedAddress,
|
||||||
|
AuthValidator: authenticator,
|
||||||
|
TLSSupport: tlsSupport,
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := server.NewServer(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to create relay server: %v", err)
|
log.Debugf("failed to create relay server: %v", err)
|
||||||
return fmt.Errorf("failed to create relay server: %v", err)
|
return fmt.Errorf("failed to create relay server: %v", err)
|
||||||
|
|||||||
@@ -8,24 +8,24 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
prefixLength = 4
|
prefixLength = 4
|
||||||
IDSize = prefixLength + sha256.Size
|
peerIDSize = prefixLength + sha256.Size
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
prefix = []byte("sha-") // 4 bytes
|
prefix = []byte("sha-") // 4 bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
|
type PeerID [peerIDSize]byte
|
||||||
func HashID(peerID string) ([]byte, string) {
|
|
||||||
idHash := sha256.Sum256([]byte(peerID))
|
func (p PeerID) String() string {
|
||||||
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
|
return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
|
||||||
var prefixedHash []byte
|
|
||||||
prefixedHash = append(prefixedHash, prefix...)
|
|
||||||
prefixedHash = append(prefixedHash, idHash[:]...)
|
|
||||||
return prefixedHash, idHashString
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HashIDToString converts a hash to a human-readable string
|
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
|
||||||
func HashIDToString(idHash []byte) string {
|
func HashID(peerID string) PeerID {
|
||||||
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
|
idHash := sha256.Sum256([]byte(peerID))
|
||||||
|
var prefixedHash [peerIDSize]byte
|
||||||
|
copy(prefixedHash[:prefixLength], prefix)
|
||||||
|
copy(prefixedHash[prefixLength:], idHash[:])
|
||||||
|
return prefixedHash
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
package messages
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHashID(t *testing.T) {
|
|
||||||
hashedID, hashedStringId := HashID("alice")
|
|
||||||
enc := HashIDToString(hashedID)
|
|
||||||
if enc != hashedStringId {
|
|
||||||
t.Errorf("expected %s, got %s", hashedStringId, enc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -9,19 +9,26 @@ import (
|
|||||||
const (
|
const (
|
||||||
MaxHandshakeSize = 212
|
MaxHandshakeSize = 212
|
||||||
MaxHandshakeRespSize = 8192
|
MaxHandshakeRespSize = 8192
|
||||||
|
MaxMessageSize = 8820
|
||||||
|
|
||||||
CurrentProtocolVersion = 1
|
CurrentProtocolVersion = 1
|
||||||
|
|
||||||
MsgTypeUnknown MsgType = 0
|
MsgTypeUnknown MsgType = 0
|
||||||
// Deprecated: Use MsgTypeAuth instead.
|
// Deprecated: Use MsgTypeAuth instead.
|
||||||
MsgTypeHello MsgType = 1
|
MsgTypeHello = 1
|
||||||
// Deprecated: Use MsgTypeAuthResponse instead.
|
// Deprecated: Use MsgTypeAuthResponse instead.
|
||||||
MsgTypeHelloResponse MsgType = 2
|
MsgTypeHelloResponse = 2
|
||||||
MsgTypeTransport MsgType = 3
|
MsgTypeTransport = 3
|
||||||
MsgTypeClose MsgType = 4
|
MsgTypeClose = 4
|
||||||
MsgTypeHealthCheck MsgType = 5
|
MsgTypeHealthCheck = 5
|
||||||
MsgTypeAuth = 6
|
MsgTypeAuth = 6
|
||||||
MsgTypeAuthResponse = 7
|
MsgTypeAuthResponse = 7
|
||||||
|
|
||||||
|
// Peers state messages
|
||||||
|
MsgTypeSubscribePeerState = 8
|
||||||
|
MsgTypeUnsubscribePeerState = 9
|
||||||
|
MsgTypePeersOnline = 10
|
||||||
|
MsgTypePeersWentOffline = 11
|
||||||
|
|
||||||
// base size of the message
|
// base size of the message
|
||||||
sizeOfVersionByte = 1
|
sizeOfVersionByte = 1
|
||||||
@@ -30,17 +37,17 @@ const (
|
|||||||
|
|
||||||
// auth message
|
// auth message
|
||||||
sizeOfMagicByte = 4
|
sizeOfMagicByte = 4
|
||||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
headerSizeAuth = sizeOfMagicByte + peerIDSize
|
||||||
offsetMagicByte = sizeOfProtoHeader
|
offsetMagicByte = sizeOfProtoHeader
|
||||||
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
||||||
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
||||||
|
|
||||||
// hello message
|
// hello message
|
||||||
headerSizeHello = sizeOfMagicByte + IDSize
|
headerSizeHello = sizeOfMagicByte + peerIDSize
|
||||||
headerSizeHelloResp = 0
|
headerSizeHelloResp = 0
|
||||||
|
|
||||||
// transport
|
// transport
|
||||||
headerSizeTransport = IDSize
|
headerSizeTransport = peerIDSize
|
||||||
offsetTransportID = sizeOfProtoHeader
|
offsetTransportID = sizeOfProtoHeader
|
||||||
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
||||||
)
|
)
|
||||||
@@ -72,6 +79,14 @@ func (m MsgType) String() string {
|
|||||||
return "close"
|
return "close"
|
||||||
case MsgTypeHealthCheck:
|
case MsgTypeHealthCheck:
|
||||||
return "health check"
|
return "health check"
|
||||||
|
case MsgTypeSubscribePeerState:
|
||||||
|
return "subscribe peer state"
|
||||||
|
case MsgTypeUnsubscribePeerState:
|
||||||
|
return "unsubscribe peer state"
|
||||||
|
case MsgTypePeersOnline:
|
||||||
|
return "peers online"
|
||||||
|
case MsgTypePeersWentOffline:
|
||||||
|
return "peers went offline"
|
||||||
default:
|
default:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
@@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
|||||||
MsgTypeAuth,
|
MsgTypeAuth,
|
||||||
MsgTypeTransport,
|
MsgTypeTransport,
|
||||||
MsgTypeClose,
|
MsgTypeClose,
|
||||||
MsgTypeHealthCheck:
|
MsgTypeHealthCheck,
|
||||||
|
MsgTypeSubscribePeerState,
|
||||||
|
MsgTypeUnsubscribePeerState:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
default:
|
default:
|
||||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||||
@@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
|||||||
MsgTypeAuthResponse,
|
MsgTypeAuthResponse,
|
||||||
MsgTypeTransport,
|
MsgTypeTransport,
|
||||||
MsgTypeClose,
|
MsgTypeClose,
|
||||||
MsgTypeHealthCheck:
|
MsgTypeHealthCheck,
|
||||||
|
MsgTypePeersOnline,
|
||||||
|
MsgTypePeersWentOffline:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
default:
|
default:
|
||||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||||
@@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
|||||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||||
// close the network connection without any response.
|
// close the network connection without any response.
|
||||||
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
@@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|||||||
|
|
||||||
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||||
|
|
||||||
msg = append(msg, peerID...)
|
msg = append(msg, peerID[:]...)
|
||||||
msg = append(msg, additions...)
|
msg = append(msg, additions...)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
@@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|||||||
// Deprecated: Use UnmarshalAuthMsg instead.
|
// Deprecated: Use UnmarshalAuthMsg instead.
|
||||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||||
// authenticate the client with the server.
|
// authenticate the client with the server.
|
||||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
|
||||||
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
@@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
|||||||
return nil, nil, errors.New("invalid magic header")
|
return nil, nil, errors.New("invalid magic header")
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
|
peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
|
||||||
|
|
||||||
|
return &peerID, msg[headerSizeHello:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use MarshalAuthResponse instead.
|
// Deprecated: Use MarshalAuthResponse instead.
|
||||||
@@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
|||||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||||
// close the network connection without any response.
|
// close the network connection without any response.
|
||||||
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
return nil, fmt.Errorf("too large auth payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))
|
msg := make([]byte, headerTotalSizeAuth+len(authPayload))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeAuth)
|
msg[1] = byte(MsgTypeAuth)
|
||||||
|
|
||||||
copy(msg[sizeOfProtoHeader:], magicHeader)
|
copy(msg[sizeOfProtoHeader:], magicHeader)
|
||||||
|
copy(msg[offsetAuthPeerID:], peerID[:])
|
||||||
msg = append(msg, peerID...)
|
copy(msg[headerTotalSizeAuth:], authPayload)
|
||||||
msg = append(msg, authPayload...)
|
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||||
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
|
||||||
if len(msg) < headerTotalSizeAuth {
|
if len(msg) < headerTotalSizeAuth {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate the magic header
|
||||||
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
||||||
return nil, nil, errors.New("invalid magic header")
|
return nil, nil, errors.New("invalid magic header")
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
|
peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
|
||||||
|
return &peerID, msg[headerTotalSizeAuth:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalAuthResponse creates a response message to the auth.
|
// MarshalAuthResponse creates a response message to the auth.
|
||||||
@@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte {
|
|||||||
// MarshalTransportMsg creates a transport message.
|
// MarshalTransportMsg creates a transport message.
|
||||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||||
// destination peer hashed ID.
|
// destination peer hashed ID.
|
||||||
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
|
func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
// todo validate size
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
msg := make([]byte, headerTotalSizeTransport+len(payload))
|
||||||
}
|
|
||||||
|
|
||||||
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeTransport)
|
msg[1] = byte(MsgTypeTransport)
|
||||||
copy(msg[sizeOfProtoHeader:], peerID)
|
copy(msg[sizeOfProtoHeader:], peerID[:])
|
||||||
msg = append(msg, payload...)
|
copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
|
||||||
if len(buf) < headerTotalSizeTransport {
|
if len(buf) < headerTotalSizeTransport {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
|
const offsetEnd = offsetTransportID + peerIDSize
|
||||||
|
var peerID PeerID
|
||||||
|
copy(peerID[:], buf[offsetTransportID:offsetEnd])
|
||||||
|
return &peerID, buf[headerTotalSizeTransport:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
func UnmarshalTransportID(buf []byte) (*PeerID, error) {
|
||||||
if len(buf) < headerTotalSizeTransport {
|
if len(buf) < headerTotalSizeTransport {
|
||||||
return nil, ErrInvalidMessageLength
|
return nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
return buf[offsetTransportID:headerTotalSizeTransport], nil
|
|
||||||
|
const offsetEnd = offsetTransportID + peerIDSize
|
||||||
|
var id PeerID
|
||||||
|
copy(id[:], buf[offsetTransportID:offsetEnd])
|
||||||
|
return &id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTransportMsg updates the peerID in the transport message.
|
// UpdateTransportMsg updates the peerID in the transport message.
|
||||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||||
// need to allocate a new byte slice.
|
// need to allocate a new byte slice.
|
||||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
func UpdateTransportMsg(msg []byte, peerID PeerID) error {
|
||||||
if len(msg) < offsetTransportID+len(peerID) {
|
if len(msg) < offsetTransportID+peerIDSize {
|
||||||
return ErrInvalidMessageLength
|
return ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
copy(msg[offsetTransportID:], peerID)
|
copy(msg[offsetTransportID:], peerID[:])
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMarshalHelloMsg(t *testing.T) {
|
func TestMarshalHelloMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
msg, err := MarshalHelloMsg(peerID, nil)
|
msg, err := MarshalHelloMsg(peerID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
@@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
if string(receivedPeerID) != string(peerID) {
|
if receivedPeerID.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalAuthMsg(t *testing.T) {
|
func TestMarshalAuthMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
msg, err := MarshalAuthMsg(peerID, []byte{})
|
msg, err := MarshalAuthMsg(peerID, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
@@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
if string(receivedPeerID) != string(peerID) {
|
if receivedPeerID.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalTransportMsg(t *testing.T) {
|
func TestMarshalTransportMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
payload := []byte("payload")
|
payload := []byte("payload")
|
||||||
msg, err := MarshalTransportMsg(peerID, payload)
|
msg, err := MarshalTransportMsg(peerID, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) {
|
|||||||
t.Fatalf("failed to unmarshal transport id: %v", err)
|
t.Fatalf("failed to unmarshal transport id: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(uPeerID) != string(peerID) {
|
if uPeerID.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) {
|
|||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(id) != string(peerID) {
|
if id.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, id)
|
t.Errorf("expected: '%s', got: '%s'", peerID, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(respPayload) != string(payload) {
|
if string(respPayload) != string(payload) {
|
||||||
|
|||||||
92
relay/messages/peer_state.go
Normal file
92
relay/messages/peer_state.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MarshalSubPeerStateMsg encodes a subscribe-to-peer-state request for ids.
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
}

// UnmarshalSubPeerStateMsg decodes the peer IDs from a subscribe message.
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// MarshalUnsubPeerStateMsg encodes an unsubscribe-from-peer-state request for ids.
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
}

// UnmarshalUnsubPeerStateMsg decodes the peer IDs from an unsubscribe message.
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// MarshalPeersOnline encodes a peers-online notification for ids.
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
}

// UnmarshalPeersOnlineMsg decodes the peer IDs from a peers-online notification.
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// MarshalPeersWentOffline encodes a peers-went-offline notification for ids.
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
}

// UnMarshalPeersWentOffline decodes the peer IDs from a peers-went-offline notification.
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type.
// The ID list is split into chunks so that no single message exceeds MaxMessageSize.
// It returns an error when ids is empty.
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
	if len(ids) == 0 {
		return nil, fmt.Errorf("no list of peer ids provided")
	}

	// Largest number of fixed-size IDs that fits after the proto header.
	const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
	var messages [][]byte

	for i := 0; i < len(ids); i += maxPeersPerMessage {
		end := i + maxPeersPerMessage
		if end > len(ids) {
			end = len(ids)
		}
		chunk := ids[i:end]

		totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
		buf := make([]byte, totalSize)
		buf[0] = byte(CurrentProtocolVersion)
		buf[1] = msgType

		// Pack the fixed-size IDs back to back after the header.
		offset := sizeOfProtoHeader
		for _, id := range chunk {
			copy(buf[offset:], id[:])
			offset += peerIDSize
		}

		messages = append(messages, buf)
	}

	return messages, nil
}

// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer.
// It validates that the payload after the proto header is an exact multiple of
// the fixed peer ID size.
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
	if len(buf) < sizeOfProtoHeader {
		return nil, fmt.Errorf("invalid message format")
	}

	if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
		return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
	}

	numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize

	ids := make([]PeerID, numIDs)
	offset := sizeOfProtoHeader
	for i := 0; i < numIDs; i++ {
		copy(ids[i][:], buf[offset:offset+peerIDSize])
		offset += peerIDSize
	}

	return ids, nil
}
144
relay/messages/peer_state_test.go
Normal file
144
relay/messages/peer_state_test.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testPeerCount = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to generate test PeerIDs
|
||||||
|
func generateTestPeerIDs(n int) []PeerID {
|
||||||
|
ids := make([]PeerID, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < peerIDSize; j++ {
|
||||||
|
ids[i][j] = byte(i + j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to compare slices of PeerID
|
||||||
|
func peerIDEqual(a, b []PeerID) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if !bytes.Equal(a[i][:], b[i][:]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
|
||||||
|
ids := generateTestPeerIDs(testPeerCount)
|
||||||
|
|
||||||
|
msgs, err := MarshalSubPeerStateMsg(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allIDs []PeerID
|
||||||
|
for _, msg := range msgs {
|
||||||
|
decoded, err := UnmarshalSubPeerStateMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
allIDs = append(allIDs, decoded...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerIDEqual(ids, allIDs) {
|
||||||
|
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
|
||||||
|
_, err := MarshalSubPeerStateMsg([]PeerID{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for empty input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
|
||||||
|
// Too short
|
||||||
|
_, err := UnmarshalSubPeerStateMsg([]byte{1})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for short input")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Misaligned length
|
||||||
|
buf := make([]byte, sizeOfProtoHeader+1)
|
||||||
|
_, err = UnmarshalSubPeerStateMsg(buf)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for misaligned input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
|
||||||
|
ids := generateTestPeerIDs(testPeerCount)
|
||||||
|
|
||||||
|
msgs, err := MarshalPeersOnline(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allIDs []PeerID
|
||||||
|
for _, msg := range msgs {
|
||||||
|
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
allIDs = append(allIDs, decoded...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerIDEqual(ids, allIDs) {
|
||||||
|
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
|
||||||
|
_, err := MarshalPeersOnline([]PeerID{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for empty input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
|
||||||
|
_, err := UnmarshalPeersOnlineMsg([]byte{1})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for short input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
|
||||||
|
ids := generateTestPeerIDs(testPeerCount)
|
||||||
|
|
||||||
|
msgs, err := MarshalPeersWentOffline(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allIDs []PeerID
|
||||||
|
for _, msg := range msgs {
|
||||||
|
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
|
||||||
|
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
allIDs = append(allIDs, decoded...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerIDEqual(ids, allIDs) {
|
||||||
|
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
|
||||||
|
_, err := MarshalPeersWentOffline([]PeerID{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for empty input")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/messages/address"
|
"github.com/netbirdio/netbird/relay/messages/address"
|
||||||
@@ -14,6 +13,12 @@ import (
|
|||||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Validator interface {
|
||||||
|
Validate(any) error
|
||||||
|
// Deprecated: Use Validate instead.
|
||||||
|
ValidateHelloMsgType(any) error
|
||||||
|
}
|
||||||
|
|
||||||
// preparedMsg contains the marshalled success response messages
|
// preparedMsg contains the marshalled success response messages
|
||||||
type preparedMsg struct {
|
type preparedMsg struct {
|
||||||
responseHelloMsg []byte
|
responseHelloMsg []byte
|
||||||
@@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
|||||||
|
|
||||||
type handshake struct {
|
type handshake struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
validator auth.Validator
|
validator Validator
|
||||||
preparedMsg *preparedMsg
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
handshakeMethodAuth bool
|
handshakeMethodAuth bool
|
||||||
peerID string
|
peerID *messages.PeerID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handshakeReceive() ([]byte, error) {
|
func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
||||||
buf := make([]byte, messages.MaxHandshakeSize)
|
buf := make([]byte, messages.MaxHandshakeSize)
|
||||||
n, err := h.conn.Read(buf)
|
n, err := h.conn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var peerID *messages.PeerID
|
||||||
bytePeerID []byte
|
|
||||||
peerID string
|
|
||||||
)
|
|
||||||
switch msgType {
|
switch msgType {
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
case messages.MsgTypeHello:
|
case messages.MsgTypeHello:
|
||||||
bytePeerID, peerID, err = h.handleHelloMsg(buf)
|
peerID, err = h.handleHelloMsg(buf)
|
||||||
case messages.MsgTypeAuth:
|
case messages.MsgTypeAuth:
|
||||||
h.handshakeMethodAuth = true
|
h.handshakeMethodAuth = true
|
||||||
bytePeerID, peerID, err = h.handleAuthMsg(buf)
|
peerID, err = h.handleAuthMsg(buf)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
@@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
h.peerID = peerID
|
h.peerID = peerID
|
||||||
return bytePeerID, nil
|
return peerID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handshakeResponse() error {
|
func (h *handshake) handshakeResponse() error {
|
||||||
@@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) {
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
peerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||||
|
|
||||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawPeerID, peerID, nil
|
return peerID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
|
||||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
|
|
||||||
if err := h.validator.Validate(authPayload); err != nil {
|
if err := h.validator.Validate(authPayload); err != nil {
|
||||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawPeerID, peerID, nil
|
return rawPeerID, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,43 +12,50 @@ import (
|
|||||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
bufferSize = 8820
|
bufferSize = messages.MaxMessageSize
|
||||||
|
|
||||||
errCloseConn = "failed to close connection to peer: %s"
|
errCloseConn = "failed to close connection to peer: %s"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Peer represents a peer connection
|
// Peer represents a peer connection
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
metrics *metrics.Metrics
|
metrics *metrics.Metrics
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
idS string
|
id messages.PeerID
|
||||||
idB []byte
|
conn net.Conn
|
||||||
conn net.Conn
|
connMu sync.RWMutex
|
||||||
connMu sync.RWMutex
|
store *store.Store
|
||||||
store *Store
|
notifier *store.PeerNotifier
|
||||||
|
|
||||||
|
peersListener *store.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeer creates a new Peer instance and prepare custom logging
|
// NewPeer creates a new Peer instance and prepare custom logging
|
||||||
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
|
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
|
||||||
stringID := messages.HashIDToString(id)
|
p := &Peer{
|
||||||
return &Peer{
|
metrics: metrics,
|
||||||
metrics: metrics,
|
log: log.WithField("peer_id", id.String()),
|
||||||
log: log.WithField("peer_id", stringID),
|
id: id,
|
||||||
idS: stringID,
|
conn: conn,
|
||||||
idB: id,
|
store: store,
|
||||||
conn: conn,
|
notifier: notifier,
|
||||||
store: store,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Work reads data from the connection
|
// Work reads data from the connection
|
||||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||||
// the message accordingly.
|
// the message accordingly.
|
||||||
func (p *Peer) Work() {
|
func (p *Peer) Work() {
|
||||||
|
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
|
||||||
defer func() {
|
defer func() {
|
||||||
|
p.notifier.RemoveListener(p.peersListener)
|
||||||
|
|
||||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
p.log.Errorf(errCloseConn, err)
|
p.log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
@@ -94,6 +101,10 @@ func (p *Peer) Work() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Peer) ID() messages.PeerID {
|
||||||
|
return p.id
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
|
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case messages.MsgTypeHealthCheck:
|
case messages.MsgTypeHealthCheck:
|
||||||
@@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
|||||||
if err := p.conn.Close(); err != nil {
|
if err := p.conn.Close(); err != nil {
|
||||||
log.Errorf(errCloseConn, err)
|
log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
|
case messages.MsgTypeSubscribePeerState:
|
||||||
|
p.handleSubscribePeerState(msg)
|
||||||
|
case messages.MsgTypeUnsubscribePeerState:
|
||||||
|
p.handleUnsubscribePeerState(msg)
|
||||||
default:
|
default:
|
||||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||||
}
|
}
|
||||||
@@ -145,7 +160,7 @@ func (p *Peer) Close() {
|
|||||||
|
|
||||||
// String returns the peer ID
|
// String returns the peer ID
|
||||||
func (p *Peer) String() string {
|
func (p *Peer) String() string {
|
||||||
return p.idS
|
return p.id.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
||||||
@@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stringPeerID := messages.HashIDToString(peerID)
|
item, ok := p.store.Peer(*peerID)
|
||||||
dp, ok := p.store.Peer(stringPeerID)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
p.log.Debugf("peer not found: %s", stringPeerID)
|
p.log.Debugf("peer not found: %s", peerID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
dp := item.(*Peer)
|
||||||
|
|
||||||
err = messages.UpdateTransportMsg(msg, p.idB)
|
err = messages.UpdateTransportMsg(msg, p.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to update transport message: %s", err)
|
p.log.Errorf("failed to update transport message: %s", err)
|
||||||
return
|
return
|
||||||
@@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
|||||||
}
|
}
|
||||||
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Peer) handleSubscribePeerState(msg []byte) {
|
||||||
|
peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to unmarshal open connection message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
|
||||||
|
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
|
||||||
|
if len(onlinePeers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.log.Debugf("response with %d online peers", len(onlinePeers))
|
||||||
|
p.sendPeersOnline(onlinePeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) handleUnsubscribePeerState(msg []byte) {
|
||||||
|
peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to unmarshal open connection message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.peersListener.RemoveInterestedPeer(peerIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
|
||||||
|
msgs, err := messages.MarshalPeersOnline(peers)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, msg := range msgs {
|
||||||
|
if _, err := p.Write(msg); err != nil {
|
||||||
|
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
||||||
|
msgs, err := messages.MarshalPeersWentOffline(peers)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, msg := range msgs {
|
||||||
|
if _, err := p.Write(msg); err != nil {
|
||||||
|
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user