Compare commits

...

12 Commits

Author SHA1 Message Date
Viktor Liu
417fa6e833 Change default interface name from wt0 to nb0 2025-07-21 17:58:56 +02:00
Bethuel Mmbaga
a7af15c4fc [management] Fix group resource count mismatch in policy (#4182) 2025-07-21 15:26:06 +03:00
Viktor Liu
d6ed9c037e [client] Fix bind exclusion routes (#4154) 2025-07-21 12:13:21 +02:00
Ali Amer
40fdeda838 [client] add new filter-by-connection-type flag (#4010)
introduces a new flag --filter-by-connection-type to the status command.
It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views.

Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters).
2025-07-21 11:55:17 +02:00
Zoltan Papp
f6e9d755e4 [client, relay] The openConn function no longer blocks the relayAddress function call (#4180)
The openConn function no longer blocks the relayAddress function call in manager layer
2025-07-21 09:46:53 +02:00
Maycon Santos
08fd460867 [management] Add validate flow response (#4172)
This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface.

- Adds a new ValidateFlowResponse method to the IntegratedValidator interface
- Integrates the validator into the management server to validate PKCE authorization flows
- Updates dependency version for management-integrations
2025-07-18 12:18:52 +02:00
Pascal Fischer
4f74509d55 [management] fix index creation if exist on mysql (#4150) 2025-07-16 15:07:31 +02:00
Maycon Santos
58185ced16 [misc] add forum post and update sign pipeline (#4155)
use old git-town version
2025-07-16 14:10:28 +02:00
Pedro Maia Costa
e67f44f47c [client] fix test (#4156) 2025-07-16 12:09:38 +02:00
Zoltan Papp
b524f486e2 [client] Fix/nil relayed address (#4153)
Fix nil pointer in Relay conn address

Meanwhile, we create a relayed net.Conn struct instance, it is possible to set the relayedURL to nil.

panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer

Fix relayed URL variable protection
Protect the channel closing
2025-07-16 00:00:18 +02:00
Zoltan Papp
0dab03252c [client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes.
- The server manages and maintains these subscriptions.
- Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
2025-07-15 10:43:42 +02:00
iisteev
e49bcc343d [client] Avoid parsing NB_NETSTACK_SKIP_PROXY if empty (#4145)
Signed-off-by: iisteev <isteevan.shetoo@is-info.fr>
2025-07-13 15:42:48 +02:00
97 changed files with 2015 additions and 786 deletions

View File

@@ -16,6 +16,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: git-town/action@v1 - uses: git-town/action@v1.2.1
with: with:
skip-single-stacks: true skip-single-stacks: true

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.20" SIGN_PIPE_VER: "v0.0.21"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -231,3 +231,17 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -307,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -26,6 +26,7 @@ var (
statusFilter string statusFilter string
ipsFilterMap map[string]struct{} ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
) )
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
@@ -45,6 +46,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -89,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
@@ -156,6 +158,15 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil return nil
} }

View File

@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -0,0 +1,15 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -16,6 +16,7 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: nbnet.WrapUDPConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address, WGAddress: s.address,

View File

@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return return
} }
m.addressMapMu.Lock() var allAddresses []string
defer m.addressMapMu.Unlock()
for _, c := range removedConns { for _, c := range removedConns {
addresses := c.getAddresses() addresses := c.getAddresses()
for _, addr := range addresses { allAddresses = append(allAddresses, addresses...)
delete(m.addressMap, addr) }
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr)
}
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
} }
} }
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
} }
m.addressMapMu.Lock() m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr] existing, ok := m.addressMap[addr]
if !ok { if !ok {
existing = []*udpMuxedConn{} existing = []*udpMuxedConn{}
} }
existing = append(existing, conn) existing = append(existing, conn)
m.addressMap[addr] = existing m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections. // We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock() m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.Unlock() m.addressMapMu.RUnlock()
var isIPv6 bool var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -0,0 +1,21 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
wrapped, ok := m.params.UDPConn.(*UDPConn)
if !ok {
return
}
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
if !ok {
return
}
nbnetConn.RemoveAddress(addr)
}

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages // wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker) // before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{ m.params.UDPConn = &UDPConn{
PacketConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
// embed UDPMux
udpMuxParams := UDPMuxParams{ udpMuxParams := UDPMuxParams{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
} }
} }
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct { type UDPConn struct {
net.PacketConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
logger logging.LeveledLogger logger logging.LeveledLogger
@@ -125,7 +124,12 @@ type udpConn struct {
address wgaddr.Address address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { // GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil { if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr) return u.handleUncachedAddress(b, addr)
} }
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted { if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil { if err := u.performFilterCheck(addr); err != nil {
return 0, err return 0, err
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) performFilterCheck(addr net.Addr) error { func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr) host, err := getHostFromAddr(addr)
if err != nil { if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err) log.Errorf("Failed to get host from address %s: %v", addr, err)

View File

@@ -3,4 +3,4 @@
package configurer package configurer
// WgInterfaceDefault is a default interface name of Netbird // WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "wt0" const WgInterfaceDefault = "nb0"

View File

@@ -41,9 +41,12 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
} }
t.tundev = nsTunDev t.tundev = nsTunDev
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) var skipProxy bool
if err != nil { if val := os.Getenv(EnvSkipProxy); val != "" {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err) skipProxy, err = strconv.ParseBool(val)
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
} }
if skipProxy { if skipProxy {
return nsTunDev, tunNet, nil return nsTunDev, tunNet, nil

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type ProxyBind struct { type ProxyBind struct {
@@ -28,6 +29,17 @@ type ProxyBind struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
} }
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
@@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
} }
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }

View File

@@ -11,6 +11,8 @@ import (
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -26,6 +28,15 @@ type ProxyWrapper struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() { func (p *ProxyWrapper) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel() e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
@@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil { if ctx.Err() != nil {
return 0, ctx.Err() return 0, ctx.Err()
} }
p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
} }

View File

@@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort)
} }
return &ebpf.ProxyWrapper{ return ebpf.NewProxyWrapper(w.ebpfProxy)
WgeBPFProxy: w.ebpfProxy,
}
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{ return proxyBind.NewProxyBind(w.bind)
Bind: w.bind,
}
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -0,0 +1,19 @@
package listener
type CloseListener struct {
listener func()
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.listener = listener
}
func (c *CloseListener) Notify() {
if c.listener != nil {
c.listener()
}
}

View File

@@ -12,4 +12,5 @@ type Proxy interface {
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func())
} }

View File

@@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err) t.Errorf("failed to free ebpf proxy: %s", err)
} }
}() }()
proxyWrapper := &ebpf.ProxyWrapper{ proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct { tests = append(tests, struct {
name string name string

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
@@ -28,6 +29,8 @@ type WGUDPProxy struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
} }
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
} }
return p return p
} }
@@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr return endpointUdpAddr
} }
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused // Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() { func (p *WGUDPProxy) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
@@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
return return
} }

View File

@@ -39,7 +39,7 @@ const (
) )
var defaultInterfaceBlacklist = []string{ var defaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo", "Tailscale", "tailscale", "docker", "veth", "br-", "lo",
} }

View File

@@ -61,7 +61,6 @@ import (
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -138,9 +137,6 @@ type Engine struct {
connMgr *ConnMgr connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager
@@ -409,12 +405,8 @@ func (e *Engine) Start() error {
DisableClientRoutes: e.config.DisableClientRoutes, DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes, DisableServerRoutes: e.config.DisableServerRoutes,
}) })
beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err := e.routeManager.Init(); err != nil {
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
} }
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
@@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
return fmt.Errorf("peer already exists: %s", peerKey) return fmt.Errorf("peer already exists: %s", peerKey)
} }
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil return nil
} }

View File

@@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
StatusRecorder: engine.statusRecorder, StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr, RelayManager: relayMgr,
}) })
_, _, err = engine.routeManager.Init() err = engine.routeManager.Init()
require.NoError(t, err) require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -1393,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i) ifaceName = fmt.Sprintf("utun1%d", i)
} else { } else {
ifaceName = fmt.Sprintf("wt%d", i) ifaceName = fmt.Sprintf("nb%d", i)
} }
wgPort := 33100 + i wgPort := 33100 + i
@@ -1481,6 +1481,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
@@ -1490,7 +1494,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
} }
func (m *mockIFaceMapper) Name() string { func (m *mockIFaceMapper) Name() string {
return "wt0" return "nb0"
} }
func (m *mockIFaceMapper) Address() wgaddr.Address { func (m *mockIFaceMapper) Address() wgaddr.Address {

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -106,10 +105,6 @@ type Conn struct {
workerRelay *WorkerRelay workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte rosenpassRemoteKey []byte
@@ -167,7 +162,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey) conn.onDisconnected(conn.config.WgConfig.RemoteKey)
} }
@@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
} }
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler conn.onConnected = handler
@@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp ep = directEp
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here // todo consider to run conn.wgWatcherWg.Wait() here
@@ -489,6 +471,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
@@ -501,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
@@ -705,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDRelay = ""
}
if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDICE = ""
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{

View File

@@ -19,6 +19,7 @@ type RelayConnInfo struct {
} }
type WorkerRelay struct { type WorkerRelay struct {
peerCtx context.Context
log *log.Entry log *log.Entry
isController bool isController bool
config ConnConfig config ConnConfig
@@ -33,8 +34,9 @@ type WorkerRelay struct {
wgWatcher *WGWatcher wgWatcher *WGWatcher
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
peerCtx: ctx,
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
@@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
if err != nil { if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection") w.log.Debugf("handled offer by reusing existing relay connection")

View File

@@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
} }
params := common.HandlerParams{ params := common.HandlerParams{
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
} }
// create new clientNetwork // create new clientNetwork
client := &Watcher{ client := &Watcher{

View File

@@ -44,7 +44,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() error
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
@@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() error {
m.routeSelector = m.initSelector() m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
return nil, nil, nil return nil
} }
if err := m.sysOps.CleanupRouting(nil); err != nil { if err := m.sysOps.CleanupRouting(nil); err != nil {
@@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err != nil { return fmt.Errorf("setup routing: %w", err)
return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
log.Info("Routing setup complete") log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil return nil
} }
func (m *DefaultManager) initSelector() *routeselector.RouteSelector { func (m *DefaultManager) initSelector() *routeselector.RouteSelector {

View File

@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
StatusRecorder: statusRecorder, StatusRecorder: statusRecorder,
}) })
_, _, err = routeManager.Init() err = routeManager.Init()
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil) defer routeManager.Stop(nil)

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
@@ -23,8 +22,8 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() error {
return nil, nil, nil return nil
} }
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface // InitialRouteRange mock implementation of InitialRouteRange from Manager interface

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -56,6 +57,10 @@ type SysOps struct {
// seq is an atomic counter for generating unique sequence numbers for route messages // seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems //nolint:unused // only used on BSD systems
seq atomic.Uint32 seq atomic.Uint32
localSubnetsCache []*net.IPNet
localSubnetsCacheMu sync.RWMutex
localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {

View File

@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
@@ -24,6 +25,8 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const localSubnetsCacheTTL = 15 * time.Minute
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
@@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate") var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager) if err := r.setupHooks(initAddresses, stateManager); err != nil {
return fmt.Errorf("setup hooks: %w", err)
}
return nil
} }
// updateState updates state on every change so it will be persisted regularly // updateState updates state on every change so it will be persisted regularly
@@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, fmt.Errorf("get next hop: %w", err) return Nexthop{}, fmt.Errorf("get next hop: %w", err)
} }
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
exitNextHop := Nexthop{ exitNextHop := nexthop
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr := vpnIntf.Address().IP vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop exitNextHop = initialNextHop
} }
@@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
} }
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
r.localSubnetsCacheMu.RLock()
cacheAge := time.Since(r.localSubnetsCacheTime)
subnets := r.localSubnetsCache
r.localSubnetsCacheMu.RUnlock()
if cacheAge > localSubnetsCacheTTL || subnets == nil {
r.localSubnetsCacheMu.Lock()
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
r.refreshLocalSubnetsCache()
}
subnets = r.localSubnetsCache
r.localSubnetsCacheMu.Unlock()
}
for _, subnet := range subnets {
if subnet.Contains(prefix.Addr().AsSlice()) {
return true, subnet
}
}
return false, nil
}
func (r *SysOps) refreshLocalSubnetsCache() {
localInterfaces, err := net.Interfaces() localInterfaces, err := net.Interfaces()
if err != nil { if err != nil {
log.Errorf("Failed to get local interfaces: %v", err) log.Errorf("Failed to get local interfaces: %v", err)
return false, nil return
} }
var newSubnets []*net.IPNet
for _, intf := range localInterfaces { for _, intf := range localInterfaces {
addrs, err := intf.Addrs() addrs, err := intf.Addrs()
if err != nil { if err != nil {
@@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
log.Errorf("Failed to convert address to IPNet: %v", addr) log.Errorf("Failed to convert address to IPNet: %v", addr)
continue continue
} }
newSubnets = append(newSubnets, ipnet)
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
} }
} }
return false, nil r.localSubnetsCache = newSubnets
r.localSubnetsCacheTime = time.Now()
} }
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
@@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil return nil
} }
var merr *multierror.Error
for _, ip := range initAddresses { for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil { if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err) merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
} }
} }
@@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return ctx.Err() return ctx.Err()
} }
var result *multierror.Error var merr *multierror.Error
for _, ip := range resolvedIPs { for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP)) merr = multierror.Append(merr, beforeHook(connID, ip.IP))
} }
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(merr)
}) })
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
@@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return afterHook(connID) return afterHook(connID)
}) })
return beforeHook, afterHook, nil nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
})
return nberrors.FormatErrorOrNil(merr)
} }
func GetNextHop(ip netip.Addr) (Nexthop, error) { func GetNextHop(ip netip.Addr) (Nexthop, error) {

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
}) })
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))

View File

@@ -10,14 +10,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{}) r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return fmt.Errorf("%s: %w", rule.description, err)
} }
} }
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
} }
originalSysctl = originalValues originalSysctl = originalValues
return nil, nil, nil return nil
} }
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.

View File

@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
IP: wgNetwork.Addr(), IP: wgNetwork.Addr(),
Network: wgNetwork, Network: wgNetwork,
}, },
name: "wt0", name: "nb0",
} }
sysOps := &SysOps{ sysOps := &SysOps{

View File

@@ -18,10 +18,9 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const InfiniteLifetime = 0xffffffff const InfiniteLifetime = 0xffffffff
@@ -137,7 +136,7 @@ const (
RouteDeleted RouteDeleted
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -1330,6 +1330,13 @@ func (x *PeerState) GetRelayAddress() string {
return "" return ""
} }
func (x *PeerState) GetConnectionType() string {
if x.Relayed {
return "Relayed"
}
return "P2P"
}
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
type LocalPeerState struct { type LocalPeerState struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`

View File

@@ -212,7 +212,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -100,7 +100,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
} }
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview { func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview {
pbFullStatus := resp.GetFullStatus() pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState() managementState := pbFullStatus.GetManagementState()
@@ -118,7 +118,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
} }
relayOverview := mapRelays(pbFullStatus.GetRelays()) relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{ overview := OutputOverview{
Peers: peersOverview, Peers: peersOverview,
@@ -193,6 +193,7 @@ func mapPeers(
prefixNamesFilter []string, prefixNamesFilter []string,
prefixNamesFilterMap map[string]struct{}, prefixNamesFilterMap map[string]struct{},
ipsFilter map[string]struct{}, ipsFilter map[string]struct{},
connectionTypeFilter string,
) PeersStateOutput { ) PeersStateOutput {
var peersStateDetail []PeerStateDetailOutput var peersStateDetail []PeerStateDetailOutput
peersConnected := 0 peersConnected := 0
@@ -208,7 +209,7 @@ func mapPeers(
transferSent := int64(0) transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) { if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) {
continue continue
} }
if isPeerConnected { if isPeerConnected {
@@ -218,10 +219,7 @@ func mapPeers(
remoteICE = pbPeerState.GetRemoteIceCandidateType() remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P" connType = pbPeerState.GetConnectionType()
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress() relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx() transferReceived = pbPeerState.GetBytesRx()
@@ -542,10 +540,11 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
return peersString return peersString
} }
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool { func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := true nameEval := true
connectionTypeEval := false
if statusFilter != "" { if statusFilter != "" {
if !strings.EqualFold(peerStatus, statusFilter) { if !strings.EqualFold(peerStatus, statusFilter) {
@@ -570,8 +569,11 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi
} else { } else {
nameEval = false nameEval = false
} }
if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) {
connectionTypeEval = true
}
return statusEval || ipEval || nameEval return statusEval || ipEval || nameEval || connectionTypeEval
} }
func toIEC(b int64) string { func toIEC(b int64) string {

View File

@@ -234,7 +234,7 @@ var overview = OutputOverview{
} }
func TestConversionFromFullStatusToOutputOverview(t *testing.T) { func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil) convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "")
assert.Equal(t, overview, convertedResult) assert.Equal(t, overview, convertedResult)
} }

View File

@@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "")
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "")
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "")
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=

View File

@@ -112,7 +112,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -292,7 +292,7 @@ var (
ephemeralManager.LoadInitialPeers(ctx) ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager) srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }

View File

@@ -2887,7 +2887,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createDNSStore(t *testing.T) (store.Store, error) { func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
@@ -40,13 +41,14 @@ type GRPCServer struct {
settingsManager settings.Manager settingsManager settings.Manager
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *types.Config config *types.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
peerLocks sync.Map peerLocks sync.Map
authManager auth.Manager authManager auth.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
} }
// NewServer creates a new Management server // NewServer creates a new Management server
@@ -60,6 +62,7 @@ func NewServer(
appMetrics telemetry.AppMetrics, appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager, ephemeralManager *EphemeralManager,
authManager auth.Manager, authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
) (*GRPCServer, error) { ) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
@@ -79,14 +82,15 @@ func NewServer(
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
accountManager: accountManager, accountManager: accountManager,
settingsManager: settingsManager, settingsManager: settingsManager,
config: config, config: config,
secretsManager: secretsManager, secretsManager: secretsManager,
authManager: authManager, authManager: authManager,
appMetrics: appMetrics, appMetrics: appMetrics,
ephemeralManager: ephemeralManager, ephemeralManager: ephemeralManager,
integratedPeerValidator: integratedPeerValidator,
}, nil }, nil
} }
@@ -850,7 +854,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
} }
flowInfoResp := &proto.PKCEAuthorizationFlow{ initInfoFlow := &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{ ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
@@ -865,6 +869,8 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
}, },
} }
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")

View File

@@ -424,9 +424,10 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
} }
if group, ok := groupsMap[gid]; ok { if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{ minimum := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
PeersCount: len(group.Peers), PeersCount: len(group.Peers),
ResourcesCount: len(group.Resources),
} }
destinations = append(destinations, minimum) destinations = append(destinations, minimum)
cache[gid] = minimum cache[gid] = minimum

View File

@@ -1,4 +1,5 @@
package testing_tools package testing_tools
import ( import (
"bytes" "bytes"
"context" "context"
@@ -132,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
} }
geoMock := &geolocation.Mock{} geoMock := &geolocation.Mock{}
validatorMock := server.MocIntegratedValidator{} validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store) proxyController := integrations.NewController(store)
userManager := users.NewManager(store) userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -101,22 +102,23 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
} }
type MocIntegratedValidator struct { type MockIntegratedValidator struct {
integrated_validator.IntegratedValidator
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil { if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
} }
return update, false, nil return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{}) validatedPeers := make(map[string]struct{})
for _, peer := range peers { for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{} validatedPeers[peer.ID] = struct{}{}
@@ -124,22 +126,22 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*ty
return validatedPeers, nil return validatedPeers, nil
} }
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
return peer return peer
} }
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) { func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
return false, false, nil return false, false, nil
} }
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil return nil
} }
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { func (MockIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
// just a dummy // just a dummy
} }
func (MocIntegratedValidator) Stop(_ context.Context) { func (MockIntegratedValidator) Stop(_ context.Context) {
// just a dummy // just a dummy
} }

View File

@@ -3,6 +3,7 @@ package integrated_validator
import ( import (
"context" "context"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@@ -17,4 +18,5 @@ type IntegratedValidator interface {
PeerDeleted(ctx context.Context, accountID, peerID string) error PeerDeleted(ctx context.Context, accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string)) SetPeerInvalidationListener(fn func(accountID string))
Stop(ctx context.Context) Stop(ctx context.Context)
ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow
} }

View File

@@ -448,7 +448,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
cleanup() cleanup()
@@ -458,7 +458,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }

View File

@@ -206,7 +206,7 @@ func startServer(
eventStore, eventStore,
nil, nil,
false, false,
server.MocIntegratedValidator{}, server.MockIntegratedValidator{},
metrics, metrics,
port_forwarding.NewControllerMock(), port_forwarding.NewControllerMock(),
settingsMockManager, settingsMockManager,
@@ -227,6 +227,7 @@ func startServer(
nil, nil,
nil, nil,
nil, nil,
server.MockIntegratedValidator{},
) )
if err != nil { if err != nil {
t.Fatalf("failed creating management server: %v", err) t.Fatalf("failed creating management server: %v", err)

View File

@@ -283,7 +283,7 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er
} }
} }
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
} }
@@ -377,6 +377,11 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error { func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
var model T var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db} stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil { if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("failed to parse model schema: %w", err) return fmt.Errorf("failed to parse model schema: %w", err)
@@ -384,6 +389,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
tableName := stmt.Schema.Table tableName := stmt.Schema.Table
dialect := db.Dialector.Name() dialect := db.Dialector.Name()
if db.Migrator().HasIndex(&model, indexName) {
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
return nil
}
var columnClause string var columnClause string
if dialect == "mysql" { if dialect == "mysql" {
var withLength []string var withLength []string

View File

@@ -4,16 +4,21 @@ import (
"context" "context"
"encoding/gob" "encoding/gob"
"net" "net"
"os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -21,7 +26,41 @@ import (
func setupDatabase(t *testing.T) *gorm.DB { func setupDatabase(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) var db *gorm.DB
var err error
var dsn string
var cleanup func()
switch os.Getenv("NETBIRD_STORE_ENGINE") {
case "mysql":
cleanup, dsn, err = testutil.CreateMysqlTestContainer()
if err != nil {
t.Fatalf("Failed to create MySQL test container: %v", err)
}
if dsn == "" {
t.Fatal("MySQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
case "postgres":
cleanup, dsn, err = testutil.CreatePostgresTestContainer()
if err != nil {
t.Fatalf("Failed to create PostgreSQL test container: %v", err)
}
if dsn == "" {
t.Fatalf("PostgreSQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case "sqlite":
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
default:
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
}
if cleanup != nil {
t.Cleanup(cleanup)
}
require.NoError(t, err, "Failed to open database") require.NoError(t, err, "Failed to open database")
return db return db
@@ -34,6 +73,7 @@ func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &route.Route{}) err := db.AutoMigrate(&types.Account{}, &route.Route{})
@@ -97,6 +137,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
@@ -117,12 +158,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
Peers []peer `gorm:"foreignKey:AccountID;references:id"` Peers []peer `gorm:"foreignKey:AccountID;references:id"`
} }
err = db.Save(&account{ a := &account{
Account: types.Account{Id: "123"}, Account: types.Account{Id: "123"},
Peers: []peer{ }
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, err = db.Save(a).Error
).Error require.NoError(t, err, "Failed to insert account")
a.Peers = []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(a).Error
require.NoError(t, err, "Failed to insert blob data") require.NoError(t, err, "Failed to insert blob data")
var blobValue string var blobValue string
@@ -143,12 +190,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.Account{ account := &types.Account{
Id: "1234", Id: "1234",
PeersG: []nbpeer.Peer{ }
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, err = db.Save(account).Error
).Error require.NoError(t, err, "Failed to insert account")
account.PeersG = []nbpeer.Peer{
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(account).Error
require.NoError(t, err, "Failed to insert JSON data") require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
@@ -162,12 +215,13 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.SetupKey{}) err := db.AutoMigrate(&types.SetupKey{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -192,6 +246,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
KeySecret: "EEFDA****", KeySecret: "EEFDA****",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -213,8 +268,9 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -235,8 +291,9 @@ func TestDropIndex(t *testing.T) {
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -249,3 +306,37 @@ func TestDropIndex(t *testing.T) {
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
assert.False(t, exist, "Should not have the index") assert.False(t, exist, "Should not have the index")
} }
func TestCreateIndex(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}
func TestCreateIndexIfExists(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Create index should not fail if index exists")
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}

View File

@@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createNSStore(t *testing.T) (store.Store, error) { func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -1273,7 +1273,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1353,7 +1353,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1496,7 +1496,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1570,7 +1570,7 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1848,7 +1848,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, true, nil return update, true, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
@@ -1870,7 +1870,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, false, nil return update, false, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)

View File

@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createRouterStore(t *testing.T) (store.Store, error) { func createRouterStore(t *testing.T) (store.Store, error) {

View File

@@ -852,7 +852,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{}, integratedPeerValidator: MockIntegratedValidator{},
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
} }

View File

@@ -7,13 +7,6 @@ import (
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
// Validator is an interface that defines the Validate method.
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct { type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator authenticator *auth.TimedHMACValidator

View File

@@ -124,15 +124,14 @@ func (cc *connContainer) close() {
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server. // While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct { type Client struct {
log *log.Entry log *log.Entry
parentCtx context.Context
connectionURL string connectionURL string
authTokenStore *auth.TokenStore authTokenStore *auth.TokenStore
hashedID []byte hashedID messages.PeerID
bufPool *sync.Pool bufPool *sync.Pool
relayConn net.Conn relayConn net.Conn
conns map[string]*connContainer conns map[messages.PeerID]*connContainer
serviceIsRunning bool serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex readLoopMutex sync.Mutex
@@ -142,14 +141,17 @@ type Client struct {
onDisconnectListener func(string) onDisconnectListener func(string)
listenerMutex sync.Mutex listenerMutex sync.Mutex
stateSubscription *PeersStateSubscription
} }
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})
c := &Client{ c := &Client{
log: log.WithFields(log.Fields{"relay": serverURL}), log: relayLog,
parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authTokenStore: authTokenStore, authTokenStore: authTokenStore,
hashedID: hashedID, hashedID: hashedID,
@@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
return &buf return &buf
}, },
}, },
conns: make(map[string]*connContainer), conns: make(map[messages.PeerID]*connContainer),
} }
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
return c return c
} }
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error { func (c *Client) Connect(ctx context.Context) error {
c.log.Infof("connecting to relay server") c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
@@ -178,17 +181,27 @@ func (c *Client) Connect() error {
return nil return nil
} }
if err := c.connect(); err != nil { instanceURL, err := c.connect(ctx)
if err != nil {
return err return err
} }
c.muInstanceURL.Lock()
c.instanceURL = instanceURL
c.muInstanceURL.Unlock()
c.log = c.log.WithField("relay", c.instanceURL.String()) c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
c.log = c.log.WithField("relay", instanceURL.String())
c.log.Infof("relay connection established") c.log.Infof("relay connection established")
c.serviceIsRunning = true c.serviceIsRunning = true
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
return nil return nil
} }
@@ -196,26 +209,50 @@ func (c *Client) Connect() error {
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress // OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise, // to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately. // it will return immediately.
// It block until the server confirm the peer is online.
// todo: what should happen if call with the same peerID with multiple times? // todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
c.mu.Lock() peerID := messages.HashID(dstPeerID)
defer c.mu.Unlock()
c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
_, ok := c.conns[peerID]
if ok {
c.mu.Unlock()
return nil, ErrConnAlreadyExists
}
c.mu.Unlock()
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
c.log.Errorf("peer not available: %s, %s", peerID, err)
return nil, err
}
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established") return nil, fmt.Errorf("relay connection is not established")
} }
hashedID, hashedStringID := messages.HashID(dstPeerID) c.muInstanceURL.Lock()
_, ok := c.conns[hashedStringID] instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)
_, ok = c.conns[peerID]
if ok { if ok {
c.mu.Unlock()
_ = conn.Close()
return nil, ErrConnAlreadyExists return nil, ErrConnAlreadyExists
} }
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
c.log.Infof("open connection to peer: %s", hashedStringID) c.mu.Unlock()
msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil return conn, nil
} }
@@ -254,76 +291,70 @@ func (c *Client) Close() error {
return c.close(true) return c.close(true)
} }
func (c *Client) connect() error { func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
return err return nil, err
} }
c.relayConn = conn c.relayConn = conn
if err = c.handShake(); err != nil { instanceURL, err := c.handShake(ctx)
if err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
} }
return err return nil, err
} }
return nil return instanceURL, nil
} }
func (c *Client) handShake() error { func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil { if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err) c.log.Errorf("failed to marshal auth message: %s", err)
return err return nil, err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
c.log.Errorf("failed to send auth message: %s", err) c.log.Errorf("failed to send auth message: %s", err)
return err return nil, err
} }
buf := make([]byte, messages.MaxHandshakeRespSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(ctx, buf)
if err != nil { if err != nil {
c.log.Errorf("failed to read auth response: %s", err) c.log.Errorf("failed to read auth response: %s", err)
return err return nil, err
} }
_, err = messages.ValidateVersion(buf[:n]) _, err = messages.ValidateVersion(buf[:n])
if err != nil { if err != nil {
return fmt.Errorf("validate version: %w", err) return nil, fmt.Errorf("validate version: %w", err)
} }
msgType, err := messages.DetermineServerMessageType(buf[:n]) msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil { if err != nil {
c.log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
return err return nil, err
} }
if msgType != messages.MsgTypeAuthResponse { if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType) c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return nil, fmt.Errorf("unexpected message type")
} }
addr, err := messages.UnmarshalAuthResponse(buf[:n]) addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil { if err != nil {
return err return nil, err
} }
c.muInstanceURL.Lock() return &RelayAddr{addr: addr}, nil
c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
} }
func (c *Client) readLoop(relayConn net.Conn) { func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var ( var (
errExit error errExit error
n int n int
@@ -366,10 +397,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
hc.Stop() hc.Stop()
c.muInstanceURL.Lock() c.stateSubscription.Cleanup()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(false) _ = c.close(false)
c.notifyDisconnected() c.notifyDisconnected()
@@ -382,6 +410,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypePeersOnline:
c.handlePeersOnlineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypePeersWentOffline:
c.handlePeersWentOfflineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypeClose: case messages.MsgTypeClose:
c.log.Debugf("relay connection close by server") c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
@@ -413,18 +449,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
stringID := messages.HashIDToString(peerID)
c.mu.Lock() c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock() c.mu.Unlock()
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return false return false
} }
container, ok := c.conns[stringID] container, ok := c.conns[*peerID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
c.log.Errorf("peer not found: %s", stringID) c.log.Errorf("peer not found: %s", peerID.String())
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return true return true
} }
@@ -437,9 +471,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
c.mu.Lock() c.mu.Lock()
conn, ok := c.conns[id] conn, ok := c.conns[dstID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
@@ -464,7 +498,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
return len(payload), err return len(payload), err
} }
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for { for {
select { select {
case _, ok := <-hc.OnTimeout: case _, ok := <-hc.OnTimeout:
@@ -478,7 +512,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
c.log.Warnf("failed to close connection: %s", err) c.log.Warnf("failed to close connection: %s", err)
} }
return return
case <-c.parentCtx.Done(): case <-ctx.Done():
err := c.close(true) err := c.close(true)
if err != nil { if err != nil {
c.log.Errorf("failed to teardown connection: %s", err) c.log.Errorf("failed to teardown connection: %s", err)
@@ -492,10 +526,31 @@ func (c *Client) closeAllConns() {
for _, container := range c.conns { for _, container := range c.conns {
container.close() container.close()
} }
c.conns = make(map[string]*connContainer) c.conns = make(map[messages.PeerID]*connContainer)
} }
func (c *Client) closeConn(connReference *Conn, id string) error { func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
c.mu.Lock()
defer c.mu.Unlock()
for _, peerID := range peerIDs {
container, ok := c.conns[peerID]
if !ok {
c.log.Warnf("can not close connection, peer not found: %s", peerID)
continue
}
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
container.close()
delete(c.conns, peerID)
}
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
}
}
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -507,6 +562,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference { if container.conn != connReference {
return fmt.Errorf("conn reference mismatch") return fmt.Errorf("conn reference mismatch")
} }
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
}
c.log.Infof("free up connection to peer: %s", id) c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id) delete(c.conns, id)
container.close() container.close()
@@ -525,8 +585,12 @@ func (c *Client) close(gracefullyExit bool) error {
c.log.Warn("relay connection was already marked as not running") c.log.Warn("relay connection was already marked as not running")
return nil return nil
} }
c.serviceIsRunning = false c.serviceIsRunning = false
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.log.Infof("closing all peer connections") c.log.Infof("closing all peer connections")
c.closeAllConns() c.closeAllConns()
if gracefullyExit { if gracefullyExit {
@@ -559,8 +623,8 @@ func (c *Client) writeCloseMsg() {
} }
} }
func (c *Client) readWithTimeout(buf []byte) (int, error) { func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
defer cancel() defer cancel()
readDone := make(chan struct{}) readDone := make(chan struct{})
@@ -581,3 +645,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) {
return n, err return n, err
} }
} }
func (c *Client) handlePeersOnlineMsg(buf []byte) {
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
return
}
c.stateSubscription.OnPeersOnline(peersID)
}
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
peersID, err := messages.UnMarshalPeersWentOffline(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
return
}
c.stateSubscription.OnPeersWentOffline(peersID)
}

View File

@@ -18,14 +18,19 @@ import (
) )
var ( var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{} hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234" serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234" serverURL = "rel://127.0.0.1:1234"
serverCfg = server.Config{
Meter: otel.Meter(""),
ExposedAddress: serverURL,
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
_ = util.InitLog("error", "console") _ = util.InitLog("debug", "console")
code := m.Run() code := m.Run()
os.Exit(code) os.Exit(code)
} }
@@ -33,7 +38,7 @@ func TestMain(m *testing.M) {
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -58,37 +63,37 @@ func TestClient(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
t.Log("alice connecting to server") t.Log("alice connecting to server")
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientAlice.Close() defer clientAlice.Close()
t.Log("placeholder connecting to server") t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect() err = clientPlaceHolder.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientPlaceHolder.Close() defer clientPlaceHolder.Close()
t.Log("Bob connecting to server") t.Log("Bob connecting to server")
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientBob.Close() defer clientBob.Close()
t.Log("Alice open connection to Bob") t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn("bob") connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
t.Log("Bob open connection to Alice") t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn("alice") connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -115,7 +120,7 @@ func TestClient(t *testing.T) {
func TestRegistration(t *testing.T) { func TestRegistration(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
_ = srv.Shutdown(ctx) _ = srv.Shutdown(ctx)
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) {
_ = fakeTCPListener.Close() _ = fakeTCPListener.Close()
}(fakeTCPListener) }(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err == nil { if err == nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
@@ -189,7 +194,7 @@ func TestEcho(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -213,8 +218,8 @@ func TestEcho(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -225,8 +230,8 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -237,12 +242,12 @@ func TestEcho(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err == nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("expected error when binding to unavailable peer, got nil")
} }
log.Infof("closing client") log.Infof("closing client")
@@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") chBob, err := clientBob.OpenConn(ctx, "alice")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
chAlice, err := clientAlice.OpenConn("bob") chAlice, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
testString := "hello alice, I am bob" testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err == nil {
t.Errorf("expected error when writing to channel, got nil")
}
chBob, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_, err = chBob.Write([]byte(testString)) _, err = chBob.Write([]byte(testString))
if err != nil { if err != nil {
t.Errorf("failed to write to channel: %s", err) t.Errorf("failed to write to channel: %s", err)
@@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) {
t.Errorf("unexpected reading from closed connection") t.Errorf("unexpected reading from closed connection")
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -544,8 +571,8 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() err = relayClient.Connect(ctx)
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@@ -567,7 +594,7 @@ func TestCloseByServer(t *testing.T) {
log.Fatalf("timeout waiting for client to disconnect") log.Fatalf("timeout waiting for client to disconnect")
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -577,7 +604,7 @@ func TestCloseByClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -596,8 +623,8 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() err = relayClient.Connect(ctx)
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@@ -607,7 +634,7 @@ func TestCloseByClient(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -623,7 +650,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -647,8 +674,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -659,8 +686,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -671,12 +698,12 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -3,13 +3,14 @@ package client
import ( import (
"net" "net"
"time" "time"
"github.com/netbirdio/netbird/relay/messages"
) )
// Conn represent a connection to a relayed remote peer. // Conn represent a connection to a relayed remote peer.
type Conn struct { type Conn struct {
client *Client client *Client
dstID []byte dstID messages.PeerID
dstStringID string
messageChan chan Msg messageChan chan Msg
instanceURL *RelayAddr instanceURL *RelayAddr
} }
@@ -17,14 +18,12 @@ type Conn struct {
// NewConn creates a new connection to a relayed remote peer. // NewConn creates a new connection to a relayed remote peer.
// client: the client instance, it used to send messages to the destination peer // client: the client instance, it used to send messages to the destination peer
// dstID: the destination peer ID // dstID: the destination peer ID
// dstStringID: the destination peer ID in string format
// messageChan: the channel where the messages will be received // messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer // instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn { func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
dstID: dstID, dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan, messageChan: messageChan,
instanceURL: instanceURL, instanceURL: instanceURL,
} }
@@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c, c.dstStringID, c.dstID, p) return c.client.writeTo(c, c.dstID, p)
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
@@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.client.closeConn(c, c.dstStringID) return c.client.closeConn(c, c.dstID)
} }
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {

View File

@@ -80,7 +80,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := rc.Connect(); err != nil { if err := rc.Connect(parentCtx); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err) log.Errorf("failed to reconnect to relay server: %s", err)
return false return false
} }

View File

@@ -42,7 +42,7 @@ type OnServerCloseListener func()
// ManagerService is the interface for the relay manager. // ManagerService is the interface for the relay manager.
type ManagerService interface { type ManagerService interface {
Serve() error Serve() error
OpenConn(serverAddress, peerKey string) (net.Conn, error) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error) RelayInstanceAddress() (string, error)
ServerURLs() []string ServerURLs() []string
@@ -65,7 +65,7 @@ type Manager struct {
relayClient *Client relayClient *Client
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
relayClientMu sync.Mutex relayClientMu sync.RWMutex
reconnectGuard *Guard reconnectGuard *Guard
relayClients map[string]*RelayTrack relayClients map[string]*RelayTrack
@@ -123,9 +123,9 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return nil, ErrRelayClientNotConnected return nil, ErrRelayClientNotConnected
@@ -141,10 +141,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
) )
if !foreign { if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey) log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(peerKey) netConn, err = m.relayClient.OpenConn(ctx, peerKey)
} else { } else {
log.Debugf("open peer connection via foreign server: %s", serverAddress) log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(serverAddress, peerKey) netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -155,8 +155,8 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
// Ready returns true if the home Relay client is connected to the relay server. // Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool { func (m *Manager) Ready() bool {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return false return false
@@ -174,8 +174,8 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed. // closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return ErrRelayClientNotConnected return ErrRelayClientNotConnected
@@ -199,8 +199,8 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication. // lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) { func (m *Manager) RelayInstanceAddress() (string, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return "", ErrRelayClientNotConnected return "", ErrRelayClientNotConnected
@@ -229,7 +229,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token) return m.tokenStore.UpdateToken(token)
} }
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server // check if already has a connection to the desired relay server
m.relayClientsMutex.RLock() m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress] rt, ok := m.relayClients[serverAddress]
@@ -240,7 +240,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
m.relayClientsMutex.RUnlock() m.relayClientsMutex.RUnlock()
@@ -255,7 +255,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
// create a new relay client and store it in the relayClients map // create a new relay client and store it in the relayClients map
@@ -264,8 +264,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
m.relayClients[serverAddress] = rt m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock() m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID) relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect() err := relayClient.Connect(m.ctx)
if err != nil { if err != nil {
rt.err = err rt.err = err
rt.Unlock() rt.Unlock()
@@ -279,7 +279,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
rt.relayClient = relayClient rt.relayClient = relayClient
rt.Unlock() rt.Unlock()
conn, err := relayClient.OpenConn(peerKey) conn, err := relayClient.OpenConn(ctx, peerKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -300,7 +300,9 @@ func (m *Manager) onServerConnected() {
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock() m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL { if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) go func(client *Client) {
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
}(m.relayClient)
} }
m.relayClientMu.Unlock() m.relayClientMu.Unlock()

View File

@@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
) )
@@ -22,16 +23,22 @@ func TestEmptyURL(t *testing.T) {
func TestForeignConn(t *testing.T) { func TestForeignConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ lstCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
srv1, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: lstCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) err := srv1.Listen(lstCfg1)
if err != nil { if err != nil {
errChan <- err errChan <- err
} }
@@ -51,7 +58,12 @@ func TestForeignConn(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: srvCfg2.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -74,32 +86,26 @@ func TestForeignConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
err = clientAlice.Serve() if err := clientAlice.Serve(); err != nil {
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
idBob := "bob" clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
log.Debugf("connect by bob") if err := clientBob.Serve(); err != nil {
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
bobsSrvAddr, err := clientBob.RelayInstanceAddress() bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil { if err != nil {
t.Fatalf("failed to get relay address: %s", err) t.Fatalf("failed to get relay address: %s", err)
} }
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -137,7 +143,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -163,7 +169,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -186,16 +192,20 @@ func TestForeginConnClose(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := mgrBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
err = mgr.Serve() err = mgr.Serve()
if err != nil { if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -212,23 +222,21 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
t.Log("binding server 1.") t.Log("binding server 1.")
err := srv1.Listen(srvCfg1) if err := srv1.Listen(srvCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
t.Logf("closing server 1.") t.Logf("closing server 1.")
err := srv1.Shutdown(ctx) if err := srv1.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
t.Logf("server 1. closed") t.Logf("server 1. closed")
@@ -241,7 +249,7 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -277,15 +285,8 @@ func TestForeginAutoClose(t *testing.T) {
} }
t.Log("open connection to another peer") t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
if err != nil { t.Fatalf("should have failed to open connection to another peer")
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
} }
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
@@ -305,7 +306,7 @@ func TestAutoReconnect(t *testing.T) {
srvCfg := server.ListenerConfig{ srvCfg := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -330,6 +331,13 @@ func TestAutoReconnect(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve() err = clientAlice.Serve()
if err != nil { if err != nil {
@@ -339,7 +347,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to get relay address: %s", err) t.Errorf("failed to get relay address: %s", err)
} }
conn, err := clientAlice.OpenConn(ra, "bob") conn, err := clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -357,7 +365,7 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second) time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob") _, err = clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to open channel: %s", err) t.Errorf("failed to open channel: %s", err)
} }
@@ -366,24 +374,27 @@ func TestAutoReconnect(t *testing.T) {
func TestNotifierDoubleAdd(t *testing.T) { func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ listenerCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: listenerCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) if err := srv.Listen(listenerCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
err := srv1.Shutdown(ctx) if err := srv.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
}() }()
@@ -392,17 +403,21 @@ func TestNotifierDoubleAdd(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve() clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
if err != nil { if err = clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob") clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
if err = clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -0,0 +1,191 @@
package client
import (
"context"
"errors"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
const (
OpenConnectionTimeout = 30 * time.Second
)
type relayedConnWriter interface {
Write(p []byte) (n int, err error)
}
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
// over a relay connection. It allows tracking peers' availability and handling offline
// events via a callback. We get online notification from the server only once.
type PeersStateSubscription struct {
log *log.Entry
relayConn relayedConnWriter
offlineCallback func(peerIDs []messages.PeerID)
listenForOfflinePeers map[messages.PeerID]struct{}
waitingPeers map[messages.PeerID]chan struct{}
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
}
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
return &PeersStateSubscription{
log: log,
relayConn: relayConn,
offlineCallback: offlineCallback,
listenForOfflinePeers: make(map[messages.PeerID]struct{}),
waitingPeers: make(map[messages.PeerID]chan struct{}),
}
}
// OnPeersOnline should be called when a notification is received that certain peers have come online.
// It checks if any of the peers are being waited on and signals their availability.
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, peerID := range peersID {
waitCh, ok := s.waitingPeers[peerID]
if !ok {
// If meanwhile the peer was unsubscribed, we don't need to signal it
continue
}
waitCh <- struct{}{}
delete(s.waitingPeers, peerID)
close(waitCh)
}
}
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
s.mu.Lock()
relevantPeers := make([]messages.PeerID, 0, len(peersID))
for _, peerID := range peersID {
if _, ok := s.listenForOfflinePeers[peerID]; ok {
relevantPeers = append(relevantPeers, peerID)
}
}
s.mu.Unlock()
if len(relevantPeers) > 0 {
s.offlineCallback(relevantPeers)
}
}
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
// Check if already waiting for this peer
s.mu.Lock()
if _, exists := s.waitingPeers[peerID]; exists {
s.mu.Unlock()
return errors.New("already waiting for peer to come online")
}
// Create a channel to wait for the peer to come online
waitCh := make(chan struct{}, 1)
s.waitingPeers[peerID] = waitCh
s.listenForOfflinePeers[peerID] = struct{}{}
s.mu.Unlock()
if err := s.subscribeStateChange(peerID); err != nil {
s.log.Errorf("failed to subscribe to peer state: %s", err)
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return err
}
// Wait for peer to come online or context to be cancelled
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
defer cancel()
select {
case _, ok := <-waitCh:
if !ok {
return fmt.Errorf("wait for peer to come online has been cancelled")
}
s.log.Debugf("peer %s is now online", peerID)
return nil
case <-timeoutCtx.Done():
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
}
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return timeoutCtx.Err()
}
}
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
msgErr := s.unsubscribeStateChange(peerIDs)
s.mu.Lock()
for _, peerID := range peerIDs {
if wch, ok := s.waitingPeers[peerID]; ok {
close(wch)
delete(s.waitingPeers, peerID)
}
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return msgErr
}
func (s *PeersStateSubscription) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
for _, waitCh := range s.waitingPeers {
close(waitCh)
}
s.waitingPeers = make(map[messages.PeerID]chan struct{})
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
}
func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
if err != nil {
return err
}
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
return err
}
}
return nil
}
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
if err != nil {
return err
}
var connWriteErr error
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
connWriteErr = err
}
}
return connWriteErr
}

View File

@@ -0,0 +1,99 @@
package client
import (
"bytes"
"context"
"testing"
"time"
"github.com/netbirdio/netbird/relay/messages"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockRelayedConn struct {
}
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
return len(p), nil
}
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
peerID := messages.HashID("peer1")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{}) // discard log output
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Launch wait in background
go func() {
time.Sleep(100 * time.Millisecond)
sub.OnPeersOnline([]messages.PeerID{peerID})
}()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.NoError(t, err)
}
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
peerID := messages.HashID("peer2")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
}
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
peerID := messages.HashID("peer3")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx := context.Background()
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
}()
time.Sleep(100 * time.Millisecond)
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "already waiting")
}
func TestUnsubscribeStateChange(t *testing.T) {
peerID := messages.HashID("peer4")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
doneChan := make(chan struct{})
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
close(doneChan)
}()
time.Sleep(100 * time.Millisecond)
err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
assert.NoError(t, err)
select {
case <-doneChan:
case <-time.After(200 * time.Millisecond):
// Expected timeout, meaning the subscription was successfully unsubscribed
t.Errorf("timeout")
}
}

View File

@@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url) log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect() err := relayClient.Connect(ctx)
resultChan <- connResult{ resultChan <- connResult{
RelayClient: relayClient, RelayClient: relayClient,
Url: url, Url: url,

View File

@@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error {
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) cfg := server.Config{
Meter: metricsServer.Meter,
ExposedAddress: cobraConfig.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
srv, err := server.NewServer(cfg)
if err != nil { if err != nil {
log.Debugf("failed to create relay server: %v", err) log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err)

View File

@@ -8,24 +8,24 @@ import (
const ( const (
prefixLength = 4 prefixLength = 4
IDSize = prefixLength + sha256.Size peerIDSize = prefixLength + sha256.Size
) )
var ( var (
prefix = []byte("sha-") // 4 bytes prefix = []byte("sha-") // 4 bytes
) )
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string type PeerID [peerIDSize]byte
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID)) func (p PeerID) String() string {
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
var prefixedHash []byte
prefixedHash = append(prefixedHash, prefix...)
prefixedHash = append(prefixedHash, idHash[:]...)
return prefixedHash, idHashString
} }
// HashIDToString converts a hash to a human-readable string // HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
func HashIDToString(idHash []byte) string { func HashID(peerID string) PeerID {
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) idHash := sha256.Sum256([]byte(peerID))
var prefixedHash [peerIDSize]byte
copy(prefixedHash[:prefixLength], prefix)
copy(prefixedHash[prefixLength:], idHash[:])
return prefixedHash
} }

View File

@@ -1,13 +0,0 @@
package messages
import (
"testing"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("alice")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
}

View File

@@ -9,19 +9,26 @@ import (
const ( const (
MaxHandshakeSize = 212 MaxHandshakeSize = 212
MaxHandshakeRespSize = 8192 MaxHandshakeRespSize = 8192
MaxMessageSize = 8820
CurrentProtocolVersion = 1 CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0 MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead. // Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1 MsgTypeHello = 1
// Deprecated: Use MsgTypeAuthResponse instead. // Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2 MsgTypeHelloResponse = 2
MsgTypeTransport MsgType = 3 MsgTypeTransport = 3
MsgTypeClose MsgType = 4 MsgTypeClose = 4
MsgTypeHealthCheck MsgType = 5 MsgTypeHealthCheck = 5
MsgTypeAuth = 6 MsgTypeAuth = 6
MsgTypeAuthResponse = 7 MsgTypeAuthResponse = 7
// Peers state messages
MsgTypeSubscribePeerState = 8
MsgTypeUnsubscribePeerState = 9
MsgTypePeersOnline = 10
MsgTypePeersWentOffline = 11
// base size of the message // base size of the message
sizeOfVersionByte = 1 sizeOfVersionByte = 1
@@ -30,17 +37,17 @@ const (
// auth message // auth message
sizeOfMagicByte = 4 sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + IDSize headerSizeAuth = sizeOfMagicByte + peerIDSize
offsetMagicByte = sizeOfProtoHeader offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
// hello message // hello message
headerSizeHello = sizeOfMagicByte + IDSize headerSizeHello = sizeOfMagicByte + peerIDSize
headerSizeHelloResp = 0 headerSizeHelloResp = 0
// transport // transport
headerSizeTransport = IDSize headerSizeTransport = peerIDSize
offsetTransportID = sizeOfProtoHeader offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
) )
@@ -72,6 +79,14 @@ func (m MsgType) String() string {
return "close" return "close"
case MsgTypeHealthCheck: case MsgTypeHealthCheck:
return "health check" return "health check"
case MsgTypeSubscribePeerState:
return "subscribe peer state"
case MsgTypeUnsubscribePeerState:
return "unsubscribe peer state"
case MsgTypePeersOnline:
return "peers online"
case MsgTypePeersWentOffline:
return "peers went offline"
default: default:
return "unknown" return "unknown"
} }
@@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
MsgTypeAuth, MsgTypeAuth,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypeSubscribePeerState,
MsgTypeUnsubscribePeerState:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
MsgTypeAuthResponse, MsgTypeAuthResponse,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypePeersOnline,
MsgTypePeersWentOffline:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
@@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID[:]...)
msg = append(msg, additions...) msg = append(msg, additions...)
return msg, nil return msg, nil
@@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthMsg instead. // Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < sizeOfProtoHeader+headerSizeHello { if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
@@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
return &peerID, msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead. // Deprecated: Use MarshalAuthResponse instead.
@@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize { if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("too large auth payload")
} }
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) msg := make([]byte, headerTotalSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth) msg[1] = byte(MsgTypeAuth)
copy(msg[sizeOfProtoHeader:], magicHeader) copy(msg[sizeOfProtoHeader:], magicHeader)
copy(msg[offsetAuthPeerID:], peerID[:])
msg = append(msg, peerID...) copy(msg[headerTotalSizeAuth:], authPayload)
msg = append(msg, authPayload...)
return msg, nil return msg, nil
} }
// UnmarshalAuthMsg extracts peerID and the auth payload from the message // UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < headerTotalSizeAuth { if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
// Validate the magic header
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
return &peerID, msg[headerTotalSizeAuth:], nil
} }
// MarshalAuthResponse creates a response message to the auth. // MarshalAuthResponse creates a response message to the auth.
@@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte {
// MarshalTransportMsg creates a transport message. // MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID. // destination peer hashed ID.
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize { // todo validate size
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) msg := make([]byte, headerTotalSizeTransport+len(payload))
}
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport) msg[1] = byte(MsgTypeTransport)
copy(msg[sizeOfProtoHeader:], peerID) copy(msg[sizeOfProtoHeader:], peerID[:])
msg = append(msg, payload...) copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
return msg, nil return msg, nil
} }
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message. // UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil const offsetEnd = offsetTransportID + peerIDSize
var peerID PeerID
copy(peerID[:], buf[offsetTransportID:offsetEnd])
return &peerID, buf[headerTotalSizeTransport:], nil
} }
// UnmarshalTransportID extracts the peerID from the transport message. // UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) { func UnmarshalTransportID(buf []byte) (*PeerID, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], nil
const offsetEnd = offsetTransportID + peerIDSize
var id PeerID
copy(id[:], buf[offsetTransportID:offsetEnd])
return &id, nil
} }
// UpdateTransportMsg updates the peerID in the transport message. // UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice. // need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error { func UpdateTransportMsg(msg []byte, peerID PeerID) error {
if len(msg) < offsetTransportID+len(peerID) { if len(msg) < offsetTransportID+peerIDSize {
return ErrInvalidMessageLength return ErrInvalidMessageLength
} }
copy(msg[offsetTransportID:], peerID) copy(msg[offsetTransportID:], peerID[:])
return nil return nil
} }

View File

@@ -5,7 +5,7 @@ import (
) )
func TestMarshalHelloMsg(t *testing.T) { func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalHelloMsg(peerID, nil) msg, err := MarshalHelloMsg(peerID, nil)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
func TestMarshalAuthMsg(t *testing.T) { func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalAuthMsg(peerID, []byte{}) msg, err := MarshalAuthMsg(peerID, []byte{})
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
@@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) {
} }
func TestMarshalTransportMsg(t *testing.T) { func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload") payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload) msg, err := MarshalTransportMsg(peerID, payload)
if err != nil { if err != nil {
@@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("failed to unmarshal transport id: %v", err) t.Fatalf("failed to unmarshal transport id: %v", err)
} }
if string(uPeerID) != string(peerID) { if uPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, uPeerID) t.Errorf("expected %s, got %s", peerID, uPeerID)
} }
@@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(id) != string(peerID) { if id.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, id) t.Errorf("expected: '%s', got: '%s'", peerID, id)
} }
if string(respPayload) != string(payload) { if string(respPayload) != string(payload) {

View File

@@ -0,0 +1,92 @@
package messages
import (
"fmt"
)
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
}
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
}
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
}
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
}
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
if len(ids) == 0 {
return nil, fmt.Errorf("no list of peer ids provided")
}
const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
var messages [][]byte
for i := 0; i < len(ids); i += maxPeersPerMessage {
end := i + maxPeersPerMessage
if end > len(ids) {
end = len(ids)
}
chunk := ids[i:end]
totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
buf := make([]byte, totalSize)
buf[0] = byte(CurrentProtocolVersion)
buf[1] = msgType
offset := sizeOfProtoHeader
for _, id := range chunk {
copy(buf[offset:], id[:])
offset += peerIDSize
}
messages = append(messages, buf)
}
return messages, nil
}
// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
if len(buf) < sizeOfProtoHeader {
return nil, fmt.Errorf("invalid message format")
}
if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
}
numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize
ids := make([]PeerID, numIDs)
offset := sizeOfProtoHeader
for i := 0; i < numIDs; i++ {
copy(ids[i][:], buf[offset:offset+peerIDSize])
offset += peerIDSize
}
return ids, nil
}

View File

@@ -0,0 +1,144 @@
package messages
import (
"bytes"
"testing"
)
const (
testPeerCount = 10
)
// Helper function to generate test PeerIDs
func generateTestPeerIDs(n int) []PeerID {
ids := make([]PeerID, n)
for i := 0; i < n; i++ {
for j := 0; j < peerIDSize; j++ {
ids[i][j] = byte(i + j)
}
}
return ids
}
// Helper function to compare slices of PeerID
func peerIDEqual(a, b []PeerID) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !bytes.Equal(a[i][:], b[i][:]) {
return false
}
}
return true
}
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalSubPeerStateMsg(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalSubPeerStateMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
_, err := MarshalSubPeerStateMsg([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
// Too short
_, err := UnmarshalSubPeerStateMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
// Misaligned length
buf := make([]byte, sizeOfProtoHeader+1)
_, err = UnmarshalSubPeerStateMsg(buf)
if err == nil {
t.Errorf("expected error for misaligned input")
}
}
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersOnline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
_, err := MarshalPeersOnline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
_, err := UnmarshalPeersOnlineMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
}
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersWentOffline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
_, err := MarshalPeersWentOffline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address" "github.com/netbirdio/netbird/relay/messages/address"
@@ -14,6 +13,12 @@ import (
authmsg "github.com/netbirdio/netbird/relay/messages/auth" authmsg "github.com/netbirdio/netbird/relay/messages/auth"
) )
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
// preparedMsg contains the marshalled success response messages // preparedMsg contains the marshalled success response messages
type preparedMsg struct { type preparedMsg struct {
responseHelloMsg []byte responseHelloMsg []byte
@@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
type handshake struct { type handshake struct {
conn net.Conn conn net.Conn
validator auth.Validator validator Validator
preparedMsg *preparedMsg preparedMsg *preparedMsg
handshakeMethodAuth bool handshakeMethodAuth bool
peerID string peerID *messages.PeerID
} }
func (h *handshake) handshakeReceive() ([]byte, error) { func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf) n, err := h.conn.Read(buf)
if err != nil { if err != nil {
@@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
} }
var ( var peerID *messages.PeerID
bytePeerID []byte
peerID string
)
switch msgType { switch msgType {
//nolint:staticcheck //nolint:staticcheck
case messages.MsgTypeHello: case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf) peerID, err = h.handleHelloMsg(buf)
case messages.MsgTypeAuth: case messages.MsgTypeAuth:
h.handshakeMethodAuth = true h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf) peerID, err = h.handleAuthMsg(buf)
default: default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
} }
@@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, err return nil, err
} }
h.peerID = peerID h.peerID = peerID
return bytePeerID, nil return peerID, nil
} }
func (h *handshake) handshakeResponse() error { func (h *handshake) handshakeResponse() error {
@@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error {
return nil return nil
} }
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) {
//nolint:staticcheck //nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) peerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
authMsg, err := authmsg.UnmarshalMsg(authData) authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal auth message: %w", err) return nil, fmt.Errorf("unmarshal auth message: %w", err)
} }
//nolint:staticcheck //nolint:staticcheck
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return peerID, nil
} }
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
if err := h.validator.Validate(authPayload); err != nil { if err := h.validator.Validate(authPayload); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return rawPeerID, nil
} }

View File

@@ -12,43 +12,50 @@ import (
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
) )
const ( const (
bufferSize = 8820 bufferSize = messages.MaxMessageSize
errCloseConn = "failed to close connection to peer: %s" errCloseConn = "failed to close connection to peer: %s"
) )
// Peer represents a peer connection // Peer represents a peer connection
type Peer struct { type Peer struct {
metrics *metrics.Metrics metrics *metrics.Metrics
log *log.Entry log *log.Entry
idS string id messages.PeerID
idB []byte conn net.Conn
conn net.Conn connMu sync.RWMutex
connMu sync.RWMutex store *store.Store
store *Store notifier *store.PeerNotifier
peersListener *store.Listener
} }
// NewPeer creates a new Peer instance and prepare custom logging // NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer { func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
stringID := messages.HashIDToString(id) p := &Peer{
return &Peer{ metrics: metrics,
metrics: metrics, log: log.WithField("peer_id", id.String()),
log: log.WithField("peer_id", stringID), id: id,
idS: stringID, conn: conn,
idB: id, store: store,
conn: conn, notifier: notifier,
store: store,
} }
return p
} }
// Work reads data from the connection // Work reads data from the connection
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. // the message accordingly.
func (p *Peer) Work() { func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() { defer func() {
p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@@ -94,6 +101,10 @@ func (p *Peer) Work() {
} }
} }
func (p *Peer) ID() messages.PeerID {
return p.id
}
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) { func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
switch msgType { switch msgType {
case messages.MsgTypeHealthCheck: case messages.MsgTypeHealthCheck:
@@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
log.Errorf(errCloseConn, err) log.Errorf(errCloseConn, err)
} }
case messages.MsgTypeSubscribePeerState:
p.handleSubscribePeerState(msg)
case messages.MsgTypeUnsubscribePeerState:
p.handleUnsubscribePeerState(msg)
default: default:
p.log.Warnf("received unexpected message type: %s", msgType) p.log.Warnf("received unexpected message type: %s", msgType)
} }
@@ -145,7 +160,7 @@ func (p *Peer) Close() {
// String returns the peer ID // String returns the peer ID
func (p *Peer) String() string { func (p *Peer) String() string {
return p.idS return p.id.String()
} }
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
@@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return return
} }
stringPeerID := messages.HashIDToString(peerID) item, ok := p.store.Peer(*peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok { if !ok {
p.log.Debugf("peer not found: %s", stringPeerID) p.log.Debugf("peer not found: %s", peerID)
return return
} }
dp := item.(*Peer)
err = messages.UpdateTransportMsg(msg, p.idB) err = messages.UpdateTransportMsg(msg, p.id)
if err != nil { if err != nil {
p.log.Errorf("failed to update transport message: %s", err) p.log.Errorf("failed to update transport message: %s", err)
return return
@@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) {
} }
p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
} }
func (p *Peer) handleSubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
if len(onlinePeers) == 0 {
return
}
p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers)
}
func (p *Peer) handleUnsubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.peersListener.RemoveInterestedPeer(peerIDs)
}
func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
msgs, err := messages.MarshalPeersOnline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}

View File

@@ -4,26 +4,55 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/url"
"strings"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
) )
type Config struct {
Meter metric.Meter
ExposedAddress string
TLSSupport bool
AuthValidator Validator
instanceURL string
}
func (c *Config) validate() error {
if c.Meter == nil {
c.Meter = otel.Meter("")
}
if c.ExposedAddress == "" {
return fmt.Errorf("exposed address is required")
}
instanceURL, err := getInstanceURL(c.ExposedAddress, c.TLSSupport)
if err != nil {
return fmt.Errorf("invalid url: %v", err)
}
c.instanceURL = instanceURL
if c.AuthValidator == nil {
return fmt.Errorf("auth validator is required")
}
return nil
}
// Relay represents the relay server // Relay represents the relay server
type Relay struct { type Relay struct {
metrics *metrics.Metrics metrics *metrics.Metrics
metricsCancel context.CancelFunc metricsCancel context.CancelFunc
validator auth.Validator validator Validator
store *Store store *store.Store
notifier *store.PeerNotifier
instanceURL string instanceURL string
preparedMsg *preparedMsg preparedMsg *preparedMsg
@@ -31,40 +60,40 @@ type Relay struct {
closeMu sync.RWMutex closeMu sync.RWMutex
} }
// NewRelay creates a new Relay instance // NewRelay creates and returns a new Relay instance.
// //
// Parameters: // Parameters:
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage //
// metrics for the relay server. // config: A Config struct that holds the configuration needed to initialize the relay server.
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this // - Meter: A metric.Meter used for emitting metrics. If not set, a default no-op meter will be used.
// address as the relay server's instance URL. // - ExposedAddress: The external address clients use to reach this relay. Required.
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The // - TLSSupport: A boolean indicating if the relay uses TLS. Affects the generated instance URL.
// instance URL depends on this value. // - AuthValidator: A Validator implementation used to authenticate peers. Required.
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
// peers.
// //
// Returns: // Returns:
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil. //
// Otherwise, the error contains the details of what went wrong. // A pointer to a Relay instance and an error. If initialization is successful, the error will be nil;
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) { // otherwise, it will contain the reason the relay could not be created (e.g., invalid configuration).
func NewRelay(config Config) (*Relay, error) {
if err := config.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %v", err)
}
ctx, metricsCancel := context.WithCancel(context.Background()) ctx, metricsCancel := context.WithCancel(context.Background())
m, err := metrics.NewMetrics(ctx, meter) m, err := metrics.NewMetrics(ctx, config.Meter)
if err != nil { if err != nil {
metricsCancel() metricsCancel()
return nil, fmt.Errorf("creating app metrics: %v", err) return nil, fmt.Errorf("creating app metrics: %v", err)
} }
peerStore := store.NewStore()
r := &Relay{ r := &Relay{
metrics: m, metrics: m,
metricsCancel: metricsCancel, metricsCancel: metricsCancel,
validator: validator, validator: config.AuthValidator,
store: NewStore(), instanceURL: config.instanceURL,
} store: peerStore,
notifier: store.NewPeerNotifier(peerStore),
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("get instance URL: %v", err)
} }
r.preparedMsg, err = newPreparedMsg(r.instanceURL) r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@@ -76,32 +105,6 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
return r, nil return r, nil
} }
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
case len(split) == 1 && tlsSupported:
addr = "rels://" + exposedAddress
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
return "", fmt.Errorf("invalid exposed address: %v", err)
}
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
return parsedURL.String(), nil
}
// Accept start to handle a new peer connection // Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now() acceptTime := time.Now()
@@ -125,14 +128,17 @@ func (r *Relay) Accept(conn net.Conn) {
return return
} }
peer := NewPeer(r.metrics, peerID, conn, r.store) peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now() storeTime := time.Now()
r.store.AddPeer(peer) r.store.AddPeer(peer)
r.notifier.PeerCameOnline(peer.ID())
r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String()) r.metrics.PeerConnected(peer.String())
go func() { go func() {
peer.Work() peer.Work()
r.notifier.PeerWentOffline(peer.ID())
r.store.DeletePeer(peer) r.store.DeletePeer(peer)
peer.log.Debugf("relay connection closed") peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
@@ -154,12 +160,12 @@ func (r *Relay) Shutdown(ctx context.Context) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
peers := r.store.Peers() peers := r.store.Peers()
for _, peer := range peers { for _, v := range peers {
wg.Add(1) wg.Add(1)
go func(p *Peer) { go func(p *Peer) {
p.CloseGracefully(ctx) p.CloseGracefully(ctx)
wg.Done() wg.Done()
}(peer) }(v.(*Peer))
} }
wg.Wait() wg.Wait()
r.metricsCancel() r.metricsCancel()

View File

@@ -6,15 +6,12 @@ import (
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/relay/tls" quictls "github.com/netbirdio/netbird/relay/tls"
log "github.com/sirupsen/logrus"
) )
// ListenerConfig is the configuration for the listener. // ListenerConfig is the configuration for the listener.
@@ -33,13 +30,22 @@ type Server struct {
listeners []listener.Listener listeners []listener.Listener
} }
// NewServer creates a new relay server instance. // NewServer creates and returns a new relay server instance.
// meter: the OpenTelemetry meter //
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format. // Parameters:
// tlsSupport: if true, the server will support TLS //
// authValidator: the auth validator to use for the server // config: A Config struct containing the necessary configuration:
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) { // - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator) // - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
// - AuthValidator: A Validator used to authenticate peers. Required.
//
// Returns:
//
// A pointer to a Server instance and an error. If the configuration is valid and initialization succeeds,
// the returned error will be nil. Otherwise, the error will describe the problem.
func NewServer(config Config) (*Server, error) {
relay, err := NewRelay(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,120 @@
package store
import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
)
type Listener struct {
ctx context.Context
store *Store
onlineChan chan messages.PeerID
offlineChan chan messages.PeerID
interestedPeersForOffline map[messages.PeerID]struct{}
interestedPeersForOnline map[messages.PeerID]struct{}
mu sync.RWMutex
}
func newListener(ctx context.Context, store *Store) *Listener {
l := &Listener{
ctx: ctx,
store: store,
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
}
return l
}
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
availablePeers := make([]messages.PeerID, 0)
l.mu.Lock()
defer l.mu.Unlock()
for _, id := range peerIDs {
l.interestedPeersForOnline[id] = struct{}{}
l.interestedPeersForOffline[id] = struct{}{}
}
// collect online peers to response back to the caller
for _, id := range peerIDs {
_, ok := l.store.Peer(id)
if !ok {
continue
}
availablePeers = append(availablePeers, id)
}
return availablePeers
}
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
l.mu.Lock()
defer l.mu.Unlock()
for _, id := range peerIDs {
delete(l.interestedPeersForOffline, id)
delete(l.interestedPeersForOnline, id)
}
}
func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) {
for {
select {
case <-l.ctx.Done():
return
case pID := <-l.onlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.onlineChan) > 0 {
pID = <-l.onlineChan
peers = append(peers, pID)
}
onPeersComeOnline(peers)
case pID := <-l.offlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.offlineChan) > 0 {
pID = <-l.offlineChan
peers = append(peers, pID)
}
onPeersWentOffline(peers)
}
}
}
func (l *Listener) peerWentOffline(peerID messages.PeerID) {
l.mu.RLock()
defer l.mu.RUnlock()
if _, ok := l.interestedPeersForOffline[peerID]; ok {
select {
case l.offlineChan <- peerID:
case <-l.ctx.Done():
}
}
}
func (l *Listener) peerComeOnline(peerID messages.PeerID) {
l.mu.Lock()
defer l.mu.Unlock()
if _, ok := l.interestedPeersForOnline[peerID]; ok {
select {
case l.onlineChan <- peerID:
case <-l.ctx.Done():
}
delete(l.interestedPeersForOnline, peerID)
}
}

View File

@@ -0,0 +1,64 @@
package store
import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
)
type PeerNotifier struct {
store *Store
listeners map[*Listener]context.CancelFunc
listenersMutex sync.RWMutex
}
func NewPeerNotifier(store *Store) *PeerNotifier {
pn := &PeerNotifier{
store: store,
listeners: make(map[*Listener]context.CancelFunc),
}
return pn
}
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
ctx, cancel := context.WithCancel(context.Background())
listener := newListener(ctx, pn.store)
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
pn.listenersMutex.Lock()
pn.listeners[listener] = cancel
pn.listenersMutex.Unlock()
return listener
}
func (pn *PeerNotifier) RemoveListener(listener *Listener) {
pn.listenersMutex.Lock()
defer pn.listenersMutex.Unlock()
cancel, ok := pn.listeners[listener]
if !ok {
return
}
cancel()
delete(pn.listeners, listener)
}
func (pn *PeerNotifier) PeerWentOffline(peerID messages.PeerID) {
pn.listenersMutex.RLock()
defer pn.listenersMutex.RUnlock()
for listener := range pn.listeners {
listener.peerWentOffline(peerID)
}
}
func (pn *PeerNotifier) PeerCameOnline(peerID messages.PeerID) {
pn.listenersMutex.RLock()
defer pn.listenersMutex.RUnlock()
for listener := range pn.listeners {
listener.peerComeOnline(peerID)
}
}

View File

@@ -1,41 +1,48 @@
package server package store
import ( import (
"sync" "sync"
"github.com/netbirdio/netbird/relay/messages"
) )
type IPeer interface {
Close()
ID() messages.PeerID
}
// Store is a thread-safe store of peers // Store is a thread-safe store of peers
// It is used to store the peers that are connected to the relay server // It is used to store the peers that are connected to the relay server
type Store struct { type Store struct {
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster peers map[messages.PeerID]IPeer
peersLock sync.RWMutex peersLock sync.RWMutex
} }
// NewStore creates a new Store instance // NewStore creates a new Store instance
func NewStore() *Store { func NewStore() *Store {
return &Store{ return &Store{
peers: make(map[string]*Peer), peers: make(map[messages.PeerID]IPeer),
} }
} }
// AddPeer adds a peer to the store // AddPeer adds a peer to the store
func (s *Store) AddPeer(peer *Peer) { func (s *Store) AddPeer(peer IPeer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()] odlPeer, ok := s.peers[peer.ID()]
if ok { if ok {
odlPeer.Close() odlPeer.Close()
} }
s.peers[peer.String()] = peer s.peers[peer.ID()] = peer
} }
// DeletePeer deletes a peer from the store // DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer *Peer) { func (s *Store) DeletePeer(peer IPeer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
dp, ok := s.peers[peer.String()] dp, ok := s.peers[peer.ID()]
if !ok { if !ok {
return return
} }
@@ -43,11 +50,11 @@ func (s *Store) DeletePeer(peer *Peer) {
return return
} }
delete(s.peers, peer.String()) delete(s.peers, peer.ID())
} }
// Peer returns a peer by its ID // Peer returns a peer by its ID
func (s *Store) Peer(id string) (*Peer, bool) { func (s *Store) Peer(id messages.PeerID) (IPeer, bool) {
s.peersLock.RLock() s.peersLock.RLock()
defer s.peersLock.RUnlock() defer s.peersLock.RUnlock()
@@ -56,11 +63,11 @@ func (s *Store) Peer(id string) (*Peer, bool) {
} }
// Peers returns all the peers in the store // Peers returns all the peers in the store
func (s *Store) Peers() []*Peer { func (s *Store) Peers() []IPeer {
s.peersLock.RLock() s.peersLock.RLock()
defer s.peersLock.RUnlock() defer s.peersLock.RUnlock()
peers := make([]*Peer, 0, len(s.peers)) peers := make([]IPeer, 0, len(s.peers))
for _, p := range s.peers { for _, p := range s.peers {
peers = append(peers, p) peers = append(peers, p)
} }

View File

@@ -0,0 +1,49 @@
package store
import (
"testing"
"github.com/netbirdio/netbird/relay/messages"
)
type MocPeer struct {
id messages.PeerID
}
func (m *MocPeer) Close() {
}
func (m *MocPeer) ID() messages.PeerID {
return m.id
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
pID := messages.HashID("peer_one")
p := &MocPeer{id: pID}
s.AddPeer(p)
s.DeletePeer(p)
if _, ok := s.Peer(pID); ok {
t.Errorf("peer was not deleted")
}
}
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
s := NewStore()
pID1 := messages.HashID("peer_one")
pID2 := messages.HashID("peer_one")
p1 := &MocPeer{id: pID1}
p2 := &MocPeer{id: pID2}
s.AddPeer(p1)
s.AddPeer(p2)
s.DeletePeer(p1)
if _, ok := s.Peer(pID2); !ok {
t.Errorf("second peer was deleted")
}
}

View File

@@ -1,85 +0,0 @@
package server
import (
"context"
"net"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p := NewPeer(m, []byte("peer_one"), nil, nil)
s.AddPeer(p)
s.DeletePeer(p)
if _, ok := s.Peer(p.String()); ok {
t.Errorf("peer was not deleted")
}
}
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
conn := &mockConn{}
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1)
s.AddPeer(p2)
s.DeletePeer(p1)
if _, ok := s.Peer(p2.String()); !ok {
t.Errorf("second peer was deleted")
}
}

33
relay/server/url.go Normal file
View File

@@ -0,0 +1,33 @@
package server
import (
"fmt"
"net/url"
"strings"
)
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
case len(split) == 1 && tlsSupported:
addr = "rels://" + exposedAddress
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
return "", fmt.Errorf("invalid exposed address: %v", err)
}
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
return parsedURL.String(), nil
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/turn/v3" "github.com/pion/turn/v3"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow" "github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/relay/auth/hmac"
@@ -22,7 +21,6 @@ import (
) )
var ( var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{} hmacTokenStore = &hmac.TokenStore{}
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
dataSize = 1024 * 1024 * 10 dataSize = 1024 * 1024 * 10
@@ -70,8 +68,12 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
port := 35000 + peerPairs port := 35000 + peerPairs
serverAddress := fmt.Sprintf("127.0.0.1:%d", port) serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
serverConnURL := fmt.Sprintf("rel://%s", serverAddress) serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
serverCfg := server.Config{
srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av) ExposedAddress: serverConnURL,
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -98,8 +100,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
clientsSender := make([]*client.Client, peerPairs) clientsSender := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsSender); i++ { for i := 0; i < cap(clientsSender); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
err := c.Connect() err := c.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -108,8 +110,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
clientsReceiver := make([]*client.Client, peerPairs) clientsReceiver := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsReceiver); i++ { for i := 0; i < cap(clientsReceiver); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
err := c.Connect() err := c.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -119,13 +121,13 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
connsSender := make([]net.Conn, 0, peerPairs) connsSender := make([]net.Conn, 0, peerPairs)
connsReceiver := make([]net.Conn, 0, peerPairs) connsReceiver := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsSender); i++ { for i := 0; i < len(clientsSender); i++ {
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i))
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connsSender = append(connsSender, conn) connsSender = append(connsSender, conn)
conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) conn, err = clientsReceiver[i].OpenConn(ctx, "sender-"+fmt.Sprint(i))
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -70,8 +70,8 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
ctx := context.Background() ctx := context.Background()
clientsSender := make([]*client.Client, peerPairs) clientsSender := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsSender); i++ { for i := 0; i < cap(clientsSender); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
if err := c.Connect(); err != nil { if err := c.Connect(ctx); err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
clientsSender[i] = c clientsSender[i] = c
@@ -79,7 +79,7 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
connsSender := make([]net.Conn, 0, peerPairs) connsSender := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsSender); i++ { for i := 0; i < len(clientsSender); i++ {
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i))
if err != nil { if err != nil {
log.Fatalf("failed to bind channel: %s", err) log.Fatalf("failed to bind channel: %s", err)
} }
@@ -156,8 +156,8 @@ func runReader(conn net.Conn) time.Duration {
func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
clientsReceiver := make([]*client.Client, peerPairs) clientsReceiver := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsReceiver); i++ { for i := 0; i < cap(clientsReceiver); i++ {
c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
err := c.Connect() err := c.Connect(context.Background())
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@@ -166,7 +166,7 @@ func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
connsReceiver := make([]net.Conn, 0, peerPairs) connsReceiver := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsReceiver); i++ { for i := 0; i < len(clientsReceiver); i++ {
conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) conn, err := clientsReceiver[i].OpenConn(context.Background(), "sender-"+fmt.Sprint(i))
if err != nil { if err != nil {
log.Fatalf("failed to bind channel: %s", err) log.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -17,11 +18,16 @@ type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. // ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
var ( var (
listenerWriteHooksMutex sync.RWMutex listenerWriteHooksMutex sync.RWMutex
listenerWriteHooks []ListenerWriteHookFunc listenerWriteHooks []ListenerWriteHookFunc
listenerCloseHooksMutex sync.RWMutex listenerCloseHooksMutex sync.RWMutex
listenerCloseHooks []ListenerCloseHookFunc listenerCloseHooks []ListenerCloseHookFunc
listenerAddressRemoveHooksMutex sync.RWMutex
listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
) )
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. // AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
@@ -38,7 +44,14 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) {
listenerCloseHooks = append(listenerCloseHooks, hook) listenerCloseHooks = append(listenerCloseHooks, hook)
} }
// RemoveListenerHooks removes all dialer hooks. // AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
listenerAddressRemoveHooksMutex.Lock()
defer listenerAddressRemoveHooksMutex.Unlock()
listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
}
// RemoveListenerHooks removes all listener hooks.
func RemoveListenerHooks() { func RemoveListenerHooks() {
listenerWriteHooksMutex.Lock() listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock() defer listenerWriteHooksMutex.Unlock()
@@ -47,6 +60,10 @@ func RemoveListenerHooks() {
listenerCloseHooksMutex.Lock() listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock() defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = nil listenerCloseHooks = nil
listenerAddressRemoveHooksMutex.Lock()
defer listenerAddressRemoveHooksMutex.Unlock()
listenerAddressRemoveHooks = nil
} }
// ListenPacket listens on the network address and returns a PacketConn // ListenPacket listens on the network address and returns a PacketConn
@@ -61,6 +78,7 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri
return nil, fmt.Errorf("listen packet: %w", err) return nil, fmt.Errorf("listen packet: %w", err)
} }
connID := GenerateConnID() connID := GenerateConnID()
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
} }
@@ -102,6 +120,45 @@ func (c *UDPConn) Close() error {
return closeConn(c.ID, c.UDPConn) return closeConn(c.ID, c.UDPConn)
} }
// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality
func WrapUDPConn(conn *net.UDPConn) *UDPConn {
return &UDPConn{
UDPConn: conn,
ID: GenerateConnID(),
seenAddrs: &sync.Map{},
}
}
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
func (c *UDPConn) RemoveAddress(addr string) {
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
return
}
ipStr, _, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("Error splitting IP address and port: %v", err)
return
}
ipAddr, err := netip.ParseAddr(ipStr)
if err != nil {
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
return
}
prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
listenerAddressRemoveHooksMutex.RLock()
defer listenerAddressRemoveHooksMutex.RUnlock()
for _, hook := range listenerAddressRemoveHooks {
if err := hook(c.ID, prefix); err != nil {
log.Errorf("Error executing listener address remove hook: %v", err)
}
}
}
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write // Lookup the address in the seenAddrs map to avoid calling the hooks for every write
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {

View File

@@ -0,0 +1,10 @@
package net
import (
"net"
)
// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking
func WrapUDPConn(conn *net.UDPConn) *net.UDPConn {
return conn
}