[relay] Feature/relay integration (#2244)

This update adds new relay integration for NetBird clients. The new relay is based on web sockets and listens on a single port.

- Adds new relay implementation with websocket with single port relaying mechanism
- refactor peer connection logic, allowing upgrade and downgrade from/to P2P connection
- peer connections are faster since it connects first to relay and then upgrades to P2P
- maintains compatibility with old clients by not using the new relay
- updates infrastructure scripts with new relay service
This commit is contained in:
Zoltan Papp
2024-09-08 12:06:14 +02:00
committed by GitHub
parent fcac02a92f
commit 0c039274a4
120 changed files with 9879 additions and 1940 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@ package peer
import (
"context"
"os"
"sync"
"testing"
"time"
@@ -11,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
var connConf = ConnConfig{
@@ -23,6 +25,12 @@ var connConf = ConnConfig{
},
}
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestNewConn_interfaceFilter(t *testing.T) {
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"}
@@ -40,7 +48,7 @@ func TestConn_GetKey(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -55,7 +63,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -63,7 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
<-conn.remoteOffersCh
<-conn.handshaker.remoteOffersCh
wg.Done()
}()
@@ -92,7 +100,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -100,7 +108,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
<-conn.remoteAnswerCh
<-conn.handshaker.remoteAnswerCh
wg.Done()
}()
@@ -128,58 +136,33 @@ func TestConn_Status(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
tables := []struct {
name string
status ConnStatus
want ConnStatus
name string
statusIce ConnStatus
statusRelay ConnStatus
want ConnStatus
}{
{"StatusConnected", StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting},
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
conn.status = table.status
conn.statusICE = table.statusIce
conn.statusRelay = table.statusRelay
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil {
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
<-conn.closeCh
wg.Done()
}()
go func() {
for {
err := conn.Close()
if err != nil {
continue
} else {
return
}
}
}()
wg.Wait()
}

View File

@@ -0,0 +1,192 @@
package peer
import (
"context"
"errors"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
var (
ErrSignalIsNotReady = errors.New("signal is not ready")
)
// IceCredentials ICE protocol credentials struct
type IceCredentials struct {
UFrag string
Pwd string
}
// OfferAnswer represents a session establishment offer or answer
type OfferAnswer struct {
IceCredentials IceCredentials
// WgListenPort is a remote WireGuard listen port.
// This field is used when establishing a direct WireGuard connection without any proxy.
// We can set the remote peer's endpoint with this port.
WgListenPort int
// Version of NetBird Agent
Version string
// RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
// This value is the local Rosenpass server public key when sending the message
RosenpassPubKey []byte
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
// This value is the local Rosenpass server address when sending the message
RosenpassAddr string
// relay server address
RelaySrvAddress string
}
type Handshaker struct {
mu sync.Mutex
ctx context.Context
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
onNewOfferListeners []func(*OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
remoteAnswerCh chan OfferAnswer
}
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
return &Handshaker{
ctx: ctx,
log: log,
config: config,
signaler: signaler,
ice: ice,
relay: relay,
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
}
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
}
func (h *Handshaker) Listen() {
for {
h.log.Debugf("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
if err != nil {
var connectionClosedError *ConnectionClosedError
if errors.As(err, &connectionClosedError) {
h.log.Tracef("stop handshaker")
return
}
h.log.Errorf("failed to received remote offer confirmation: %s", err)
continue
}
h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
for _, listener := range h.onNewOfferListeners {
go listener(remoteOfferAnswer)
}
}
}
func (h *Handshaker) SendOffer() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.sendOffer()
}
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
select {
case h.remoteOffersCh <- offer:
return true
default:
h.log.Debugf("OnRemoteOffer skipping message because is not ready")
// connection might not be ready yet to receive so we ignore the message
return false
}
}
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
select {
case h.remoteAnswerCh <- answer:
return true
default:
// connection might not be ready yet to receive so we ignore the message
h.log.Debugf("OnRemoteAnswer skipping message because is not ready")
return false
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
err := h.sendAnswer()
if err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-h.ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
}
}
// sendOffer prepares local user credentials and signals them to the remote peer
func (h *Handshaker) sendOffer() error {
if !h.signaler.Ready() {
return ErrSignalIsNotReady
}
iceUFrag, icePwd := h.ice.GetLocalUserCredentials()
offer := OfferAnswer{
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
offer.RelaySrvAddress = addr
}
return h.signaler.SignalOffer(offer, h.config.Key)
}
func (h *Handshaker) sendAnswer() error {
h.log.Debugf("sending answer")
uFrag, pwd := h.ice.GetLocalUserCredentials()
answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
answer.RelaySrvAddress = addr
}
err = h.signaler.SignalAnswer(answer, h.config.Key)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,70 @@
package peer
import (
"github.com/pion/ice/v3"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
)
type Signaler struct {
signal signal.Client
wgPrivateKey wgtypes.Key
}
func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler {
return &Signaler{
signal: signal,
wgPrivateKey: wgPrivateKey,
}
}
func (s *Signaler) SignalOffer(offer OfferAnswer, remoteKey string) error {
return s.signalOfferAnswer(offer, remoteKey, sProto.Body_OFFER)
}
func (s *Signaler) SignalAnswer(offer OfferAnswer, remoteKey string) error {
return s.signalOfferAnswer(offer, remoteKey, sProto.Body_ANSWER)
}
func (s *Signaler) SignalICECandidate(candidate ice.Candidate, remoteKey string) error {
return s.signal.Send(&sProto.Message{
Key: s.wgPrivateKey.PublicKey().String(),
RemoteKey: remoteKey,
Body: &sProto.Body{
Type: sProto.Body_CANDIDATE,
Payload: candidate.Marshal(),
},
})
}
func (s *Signaler) Ready() bool {
return s.signal.Ready()
}
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
msg, err := signal.MarshalCredential(
s.wgPrivateKey,
offerAnswer.WgListenPort,
remoteKey,
&signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
bodyType,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress)
if err != nil {
return err
}
err = s.signal.Send(msg)
if err != nil {
return err
}
return nil
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"errors"
"net/netip"
"slices"
"sync"
"time"
@@ -13,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
relayClient "github.com/netbirdio/netbird/relay/client"
)
// State contains the latest state of a peer
@@ -24,11 +26,11 @@ type State struct {
ConnStatus ConnStatus
ConnStatusUpdate time.Time
Relayed bool
Direct bool
LocalIceCandidateType string
RemoteIceCandidateType string
LocalIceCandidateEndpoint string
RemoteIceCandidateEndpoint string
RelayServerAddress string
LastWireguardHandshake time.Time
BytesTx int64
BytesRx int64
@@ -142,6 +144,8 @@ type Status struct {
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
peerListChangedForNotification bool
relayMgr *relayClient.Manager
}
// NewRecorder returns a new Status instance
@@ -156,6 +160,12 @@ func NewRecorder(mgmAddress string) *Status {
}
}
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.mux.Lock()
defer d.mux.Unlock()
d.relayMgr = manager
}
// ReplaceOfflinePeers replaces
func (d *Status) ReplaceOfflinePeers(replacement []State) {
d.mux.Lock()
@@ -231,17 +241,17 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState, peerState)
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.Direct = receivedState.Direct
peerState.Relayed = receivedState.Relayed
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
peerState.RelayServerAddress = receivedState.RelayServerAddress
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
}
@@ -261,6 +271,146 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
if receivedState.IP != "" {
peerState.IP = receivedState.IP
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.Relayed = receivedState.Relayed
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.Relayed = receivedState.Relayed
peerState.RelayServerAddress = receivedState.RelayServerAddress
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.RelayServerAddress = ""
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
return errors.New("peer doesn't exist")
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
d.peers[receivedState.PubKey] = peerState
if skipNotification {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
d.mux.Lock()
@@ -280,13 +430,13 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats)
return nil
}
func shouldSkipNotify(received, curr State) bool {
func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
switch {
case received.ConnStatus == StatusConnecting:
case receivedConnStatus == StatusConnecting:
return true
case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
return true
case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
return curr.IP != ""
default:
return false
@@ -502,8 +652,35 @@ func (d *Status) GetSignalState() SignalState {
}
}
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
if d.relayMgr == nil {
return d.relayStates
}
// extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates)
var relayState relay.ProbeResult
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
})
}
return relayStates
}
relayState.Err = err
}
relayState.URI = instanceAddr
return append(relayStates, relayState)
}
func (d *Status) GetDNSStates() []NSGroupState {
@@ -535,7 +712,6 @@ func (d *Status) GetFullStatus() FullStatus {
}
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
return fullStatus
}

View File

@@ -2,8 +2,8 @@ package peer
import (
"errors"
"testing"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
@@ -43,7 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -64,7 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -83,7 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -108,7 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(conn.config.ICEConfig.InterfaceBlackList)
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList)
}

View File

@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.ICEConfig.InterfaceBlackList)
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
}

View File

@@ -0,0 +1,470 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/pion/ice/v3"
"github.com/pion/randutil"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route"
)
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
lenUFrag = 16
lenPwd = 32
runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
var (
failedTimeout = 6 * time.Second
)
type ICEConfig struct {
// StunTurn is a list of STUN and TURN URLs
StunTurn *atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
NATExternalIPs []string
}
type ICEConnInfo struct {
RemoteConn net.Conn
RosenpassPubKey []byte
RosenpassAddr string
LocalIceCandidateType string
RemoteIceCandidateType string
RemoteIceCandidateEndpoint string
LocalIceCandidateEndpoint string
Relayed bool
RelayedOnLocal bool
}
type WorkerICECallbacks struct {
OnConnReady func(ConnPriority, ICEConnInfo)
OnStatusChanged func(ConnStatus)
}
type WorkerICE struct {
ctx context.Context
log *log.Entry
config ConnConfig
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
statusRecorder *Status
hasRelayOnLocally bool
conn WorkerICECallbacks
selectedPriority ConnPriority
agent *ice.Agent
muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string
localPwd string
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
w := &WorkerICE{
ctx: ctx,
log: log,
config: config,
signaler: signaler,
iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally,
conn: callBacks,
}
localUfrag, localPwd, err := generateICECredentials()
if err != nil {
return nil, err
}
w.localUfrag = localUfrag
w.localPwd = localPwd
return w, nil
}
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE")
w.muxAgent.Lock()
if w.agent != nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = candidateTypesP2P()
} else {
w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = candidateTypes()
}
w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes)
if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return
}
w.agent = agent
w.muxAgent.Unlock()
w.log.Debugf("gather candidates")
err = w.agent.GatherCandidates()
if err != nil {
w.log.Debugf("failed to gather candidates: %s", err)
return
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer)
if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err)
return
}
w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
return
}
if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
// To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
go w.punchRemoteWGPort(pair, remoteWgPort)
}
ci := ICEConnInfo{
RemoteConn: remoteConn,
RosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := w.agent.AddRemoteCandidate(candidate)
if err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
err := w.agent.Close()
if err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
transportNet, err := w.newStdNet()
if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err)
}
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: relaySupport,
InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
UDPMux: w.config.ICEConfig.UDPMux,
UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: w.localUfrag,
LocalPwd: w.localPwd,
}
if w.config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
w.sentExtraSrflx = false
agent, err := ice.NewAgent(agentConfig)
if err != nil {
return nil, err
}
err = agent.OnCandidate(w.onICECandidate)
if err != nil {
return nil, err
}
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil
w.muxAgent.Unlock()
}
})
if err != nil {
return nil, err
}
err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
if err != nil {
return nil, err
}
err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
if err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
})
if err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return
}
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
if !ok {
w.log.Warn("invalid udp mux conversion")
return
}
_, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
if err != nil {
w.log.Warnf("got an error while sending the punch packet, err: %s", err)
}
}
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
// and then signals them to the remote peer
func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
// nil means candidate gathering has been ended
if candidate == nil {
return
}
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
w.log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := w.signaler.SignalICECandidate(candidate, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key
if isControlling {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}
func candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isRelayed(pair *ice.CandidatePair) bool {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
return true
}
return false
}
func generateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {
return "", "", err
}
pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
if err != nil {
return "", "", err
}
return ufrag, pwd, nil
}

View File

@@ -0,0 +1,223 @@
package peer
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
relayClient "github.com/netbirdio/netbird/relay/client"
)
var (
wgHandshakePeriod = 2 * time.Minute
wgHandshakeOvertime = 30 * time.Second
)
type RelayConnInfo struct {
relayedConn net.Conn
rosenpassPubKey []byte
rosenpassAddr string
}
type WorkerRelayCallbacks struct {
OnConnReady func(RelayConnInfo)
OnDisconnected func()
}
type WorkerRelay struct {
log *log.Entry
config ConnConfig
relayManager relayClient.ManagerService
callBacks WorkerRelayCallbacks
relayedConn net.Conn
relayLock sync.Mutex
ctxWgWatch context.Context
ctxCancelWgWatch context.CancelFunc
ctxLock sync.Mutex
relaySupportedOnRemotePeer atomic.Bool
}
func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
r := &WorkerRelay{
log: log,
config: config,
relayManager: relayManager,
callBacks: callbacks,
}
return r
}
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if !w.isRelaySupported(remoteOfferAnswer) {
w.log.Infof("Relay is not supported by remote peer")
w.relaySupportedOnRemotePeer.Store(false)
return
}
w.relaySupportedOnRemotePeer.Store(true)
// the relayManager will return with error in case if the connection has lost with relay server
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
if err != nil {
w.log.Errorf("failed to handle new offer: %s", err)
return
}
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Infof("do not need to reopen relay connection")
return
}
w.log.Errorf("failed to open connection via Relay: %s", err)
return
}
w.relayLock.Lock()
w.relayedConn = relayedConn
w.relayLock.Unlock()
err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected)
if err != nil {
log.Errorf("failed to add close listener: %s", err)
_ = relayedConn.Close()
return
}
w.log.Debugf("peer conn opened via Relay: %s", srv)
go w.callBacks.OnConnReady(RelayConnInfo{
relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
})
}
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil {
return
}
ctx, ctxCancel := context.WithCancel(ctx)
go w.wgStateCheck(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
}
func (w *WorkerRelay) DisableWgWatcher() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancelWgWatch == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancelWgWatch()
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress()
}
func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool {
return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally()
}
func (w *WorkerRelay) IsController() bool {
return w.config.LocalKey > w.config.Key
}
func (w *WorkerRelay) RelayIsSupportedLocally() bool {
return w.relayManager.HasRelayAddress()
}
func (w *WorkerRelay) CloseConn() {
w.relayLock.Lock()
defer w.relayLock.Unlock()
if w.relayedConn == nil {
return
}
err := w.relayedConn.Close()
if err != nil {
w.log.Warnf("failed to close relay connection: %v", err)
}
}
// wgStateCheck help to check the state of the wireguard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
expected := wgHandshakeOvertime
for {
select {
case <-timer.C:
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
w.log.Tracef("last handshake: %v", lastHandshake)
if time.Since(lastHandshake) > expected {
w.log.Infof("Wireguard handshake timed out, closing relay connection")
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
timer.Reset(resetTime)
expected = wgHandshakePeriod
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
}
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() {
return false
}
return answer.RelaySrvAddress != ""
}
func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
if w.IsController() {
return myRelayAddress
}
return remoteRelayAddress
}
func (w *WorkerRelay) wgState() (time.Time, error) {
wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key)
if err != nil {
return time.Time{}, err
}
return wgState.LastHandshake, nil
}
func (w *WorkerRelay) onRelayMGDisconnected() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancelWgWatch != nil {
w.ctxCancelWgWatch()
}
go w.callBacks.OnDisconnected()
}