Compare commits

..

14 Commits

Author SHA1 Message Date
Zoltan Papp
940367e1c6 Remove unused wait group 2023-04-18 18:59:11 +02:00
Zoltan Papp
c80cb22cc0 Add parallel processing 2023-04-17 16:44:45 +02:00
Zoltan Papp
d0e51dfc11 Fix Linux implementation 2023-04-17 14:46:14 +02:00
Zoltan Papp
7e11842a5e Fix unix implementation 2023-04-17 14:37:51 +02:00
pzoli
1bf3e6feec Better pkg destroyer 2023-04-17 14:37:51 +02:00
pzoli
c392fe44e4 Handle STUN msg 2023-04-17 14:37:51 +02:00
pzoli
fef24b4a4d Update ICEBind to the latest Bind version 2023-04-17 14:37:48 +02:00
Zoltan Papp
bb1efe68aa Update wireguard version to 9e2f38602202 2023-04-17 14:37:28 +02:00
Zoltan Papp
bb147c2a7c Remove unnecessary uapi open (#807)
Remove unnecessary uapi open from Android implementation
2023-04-17 11:50:12 +02:00
Zoltan Papp
4616bc5258 Add route management for Android interface (#801)
Support client route management feature on Android
2023-04-17 11:15:37 +02:00
Zoltan Papp
1803cf3678 Fix error handling in case of the port is in used (#810) 2023-04-14 16:18:00 +02:00
Zoltan Papp
9f35a7fb8d Ignore ipv6 labeled address (#809)
Ignore ipv6 labeled address
2023-04-14 15:40:27 +02:00
Misha Bragin
2eeed55c18 Bind implementation (#779)
This PR adds supports for the WireGuard userspace implementation
using Bind interface from wireguard-go. 
The newly introduced ICEBind struct implements Bind with UDPMux-based
structs from pion/ice to handle hole punching using ICE.
The core implementation was taken from StdBind of wireguard-go.

The result is a single WireGuard port that is used for host and server reflexive candidates. 
Relay candidates are still handled separately and will be integrated in the following PRs.

ICEBind checks the incoming packets for being STUN or WireGuard ones
and routes them to UDPMux (to handle hole punching) or to WireGuard  respectively.
2023-04-13 17:00:01 +02:00
Givi Khojanashvili
0343c5f239 Rollback simple ACL rules processing. (#803) 2023-04-12 09:39:17 +02:00
63 changed files with 1759 additions and 648 deletions

View File

@@ -144,13 +144,19 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover) engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, statusRecorder) md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
if err != nil {
log.Error(err)
return wrapErr(err)
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
@@ -194,13 +200,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
} }
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address, WgAddr: peerConfig.Address,
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
IFaceBlackList: config.IFaceBlackList, IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery, DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key, WgPrivateKey: key,

View File

@@ -206,7 +206,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet) wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -18,8 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -46,10 +46,6 @@ var ErrResetConnection = fmt.Errorf("reset connection")
type EngineConfig struct { type EngineConfig struct {
WgPort int WgPort int
WgIfaceName string WgIfaceName string
// TunAdapter is option. It is necessary for mobile version.
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
// WgAddr is a Wireguard local address (Netbird Network IP) // WgAddr is a Wireguard local address (Netbird Network IP)
WgAddr string WgAddr string
@@ -89,7 +85,9 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing // syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex syncMsgMux *sync.Mutex
config *EngineConfig config *EngineConfig
mobileDep MobileDependency
// STUNs is a list of STUN servers used by ICE // STUNs is a list of STUN servers used by ICE
STUNs []*ice.URL STUNs []*ice.URL
// TURNs is a list of STUN servers used by ICE // TURNs is a list of STUN servers used by ICE
@@ -129,7 +127,7 @@ type Peer struct {
func NewEngine( func NewEngine(
ctx context.Context, cancel context.CancelFunc, ctx context.Context, cancel context.CancelFunc,
signalClient signal.Client, mgmClient mgm.Client, signalClient signal.Client, mgmClient mgm.Client,
config *EngineConfig, statusRecorder *peer.Status, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
ctx: ctx, ctx: ctx,
@@ -139,6 +137,7 @@ func NewEngine(
peerConns: make(map[string]*peer.Conn), peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{},
config: config, config: config,
mobileDep: mobileDep,
STUNs: []*ice.URL{}, STUNs: []*ice.URL{},
TURNs: []*ice.URL{}, TURNs: []*ice.URL{},
networkSerial: 0, networkSerial: 0,
@@ -178,9 +177,9 @@ func (e *Engine) Start() error {
var err error var err error
transportNet, err := e.newStdNet() transportNet, err := e.newStdNet()
if err != nil { if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter, transportNet) e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error()) log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
return err return err
@@ -808,6 +807,14 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
stunTurn = append(stunTurn, e.STUNs...) stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...) stunTurn = append(stunTurn, e.TURNs...)
proxyConfig := proxy.Config{
RemoteKey: pubKey,
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
WgInterface: e.wgInterface,
AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey,
}
// randomize connection timeout // randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{ config := peer.ConnConfig{
@@ -819,12 +826,13 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
Timeout: timeout, Timeout: timeout,
UDPMux: e.udpMux, UDPMux: e.udpMux,
UDPMuxSrflx: e.udpMuxSrflx, UDPMuxSrflx: e.udpMuxSrflx,
ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
AllowedIPs: allowedIPs, UserspaceBind: e.wgInterface.IsUserspaceBind(),
} }
peerConn, err := peer.NewConn(config, e.wgInterface, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover) peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.config.IFaceDiscover, e.config.IFaceBlackList) return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
} }

View File

@@ -74,7 +74,7 @@ func TestEngine_SSH(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -208,12 +208,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -404,7 +404,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -562,12 +562,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@@ -731,12 +731,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
@@ -1000,7 +1000,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort, WgPort: wgPort,
} }
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
} }
func startSignal() (*grpc.Server, string, error) { func startSignal() (*grpc.Server, string, error) {

View File

@@ -0,0 +1,13 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
)
// MobileDependency collect all dependencies for mobile platform
type MobileDependency struct {
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
Routes []string
}

View File

@@ -0,0 +1,29 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: ifaceDiscover,
}
err := md.readMap(mgmClient)
return md, err
}
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
routes, err := mgmClient.GetRoutes()
if err != nil {
return err
}
d.Routes = make([]string, len(routes))
for i, r := range routes {
d.Routes[i] = r.GetNetwork()
}
return nil
}

View File

@@ -0,0 +1,13 @@
//go:build !android
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
return MobileDependency{}, nil
}

View File

@@ -37,6 +37,8 @@ type ConnConfig struct {
Timeout time.Duration Timeout time.Duration
ProxyConfig proxy.Config
UDPMux ice.UDPMux UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux UDPMuxSrflx ice.UniversalUDPMux
@@ -44,7 +46,8 @@ type ConnConfig struct {
NATExternalIPs []string NATExternalIPs []string
AllowedIPs string // UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
} }
// OfferAnswer represents a session establishment offer or answer // OfferAnswer represents a session establishment offer or answer
@@ -93,7 +96,6 @@ type Conn struct {
remoteModeCh chan ModeMessage remoteModeCh chan ModeMessage
meta meta meta meta
wgIface *iface.WGIface
adapter iface.TunAdapter adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
} }
@@ -121,7 +123,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(config ConnConfig, wgIface *iface.WGIface, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) { func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
return &Conn{ return &Conn{
config: config, config: config,
mu: sync.Mutex{}, mu: sync.Mutex{},
@@ -133,7 +135,6 @@ func NewConn(config ConnConfig, wgIface *iface.WGIface, statusRecorder *Status,
remoteModeCh: make(chan ModeMessage, 1), remoteModeCh: make(chan ModeMessage, 1),
adapter: adapter, adapter: adapter,
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
wgIface: wgIface,
}, nil }, nil
} }
@@ -146,7 +147,7 @@ func (conn *Conn) reCreateAgent() error {
var err error var err error
transportNet, err := conn.newStdNet() transportNet, err := conn.newStdNet()
if err != nil { if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }
agentConfig := &ice.AgentConfig{ agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled, MulticastDNSMode: ice.MulticastDNSModeDisabled,
@@ -197,7 +198,7 @@ func (conn *Conn) Open() error {
peerState := State{PubKey: conn.config.Key} peerState := State{PubKey: conn.config.Key}
peerState.IP = strings.Split(conn.config.AllowedIPs, "/")[0] peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
peerState.ConnStatusUpdate = time.Now() peerState.ConnStatusUpdate = time.Now()
peerState.ConnStatus = conn.status peerState.ConnStatus = conn.status
@@ -328,18 +329,28 @@ func (conn *Conn) Open() error {
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool { func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
if !isRelayCandidate(pair.Local) && userspaceBind { if !isRelayCandidate(pair.Local) && userspaceBind {
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
return false return false
} }
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) { if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
return false return false
} }
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) { if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
return false return false
} }
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) { if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
return false
}
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
return false return false
} }
@@ -348,16 +359,8 @@ func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
func isSameNetworkPrefix(pair *ice.CandidatePair) bool { func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
localIPStr, _, err := net.SplitHostPort(pair.Local.Address()) localIP := net.ParseIP(pair.Local.Address())
if err != nil { remoteIP := net.ParseIP(pair.Remote.Address())
return false
}
remoteIPStr, _, err := net.SplitHostPort(pair.Remote.Address())
if err != nil {
return false
}
localIP := net.ParseIP(localIPStr)
remoteIP := net.ParseIP(remoteIPStr)
if localIP == nil || remoteIP == nil { if localIP == nil || remoteIP == nil {
return false return false
} }
@@ -382,9 +385,13 @@ func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address()) return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
} }
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
}
func isPublicIP(address string) bool { func isPublicIP(address string) bool {
ip := net.ParseIP(address) ip := net.ParseIP(address)
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() { if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
return false return false
} }
return true return true
@@ -418,7 +425,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
} }
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
@@ -429,7 +436,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
} }
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy { func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
useProxy := shouldUseProxy(pair, conn.wgIface.IsUserspaceBind()) useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
localDirectMode := !useProxy localDirectMode := !useProxy
remoteDirectMode := localDirectMode remoteDirectMode := localDirectMode
@@ -439,13 +446,12 @@ func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgP
remoteDirectMode = conn.receiveRemoteDirectMode() remoteDirectMode = conn.receiveRemoteDirectMode()
} }
if conn.wgIface.IsUserspaceBind() && localDirectMode { if conn.config.UserspaceBind && localDirectMode {
return proxy.NewNoProxy(conn.config.ProxyConfig) return proxy.NewNoProxy(conn.config.ProxyConfig)
} }
if localDirectMode && remoteDirectMode { if localDirectMode && remoteDirectMode {
//wgInterface *iface.WGIface, remoteKey string, allowedIps string, preSharedKey *wgtypes.Key, remoteWgPort int) return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
return proxy.NewDirectNoProxy(conn.wgIface, conn.config.Key, conn.config.AllowedIPs, remoteWgPort)
} }
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key) log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)

View File

@@ -1,11 +1,12 @@
package peer package peer
import ( import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@@ -203,7 +204,7 @@ func TestConn_ShouldUseProxy(t *testing.T) {
} }
privateHostCandidate := &mockICECandidate{ privateHostCandidate := &mockICECandidate{
AddressFunc: func() string { AddressFunc: func() string {
return "10.0.0.1:44576" return "10.0.0.1"
}, },
TypeFunc: func() ice.CandidateType { TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost return ice.CandidateTypeHost
@@ -322,6 +323,42 @@ func TestConn_ShouldUseProxy(t *testing.T) {
}, },
expected: false, expected: false,
}, },
{
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: false,
},
{
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: true,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {

View File

@@ -1,11 +1,8 @@
package proxy package proxy
import ( import (
"net"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net"
"github.com/netbirdio/netbird/iface"
) )
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard. // DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
@@ -15,28 +12,20 @@ import (
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820). // DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being. // In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
type DirectNoProxy struct { type DirectNoProxy struct {
wgInterface *iface.WGIface config Config
remoteKey string
allowedIps string
// RemoteWgListenPort is a WireGuard port of a remote peer. // RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port. // It is used instead of the hardcoded 51820 port.
remoteWgListenPort int RemoteWgListenPort int
} }
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port // NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
func NewDirectNoProxy(wgInterface *iface.WGIface, remoteKey string, allowedIps string, remoteWgPort int) *DirectNoProxy { func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
return &DirectNoProxy{ return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
wgInterface: wgInterface,
remoteKey: remoteKey,
allowedIps: allowedIps,
remoteWgListenPort: remoteWgPort}
} }
// Close removes peer from the WireGuard interface // Close removes peer from the WireGuard interface
func (p *DirectNoProxy) Close() error { func (p *DirectNoProxy) Close() error {
err := p.wgInterface.RemovePeer(p.remoteKey) err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil { if err != nil {
return err return err
} }
@@ -46,13 +35,14 @@ func (p *DirectNoProxy) Close() error {
// Start just updates WireGuard peer with the remote IP and default WireGuard port // Start just updates WireGuard peer with the remote IP and default WireGuard port
func (p *DirectNoProxy) Start(remoteConn net.Conn) error { func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using DirectNoProxy while connecting to peer %s", p.remoteKey) log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String()) addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil { if err != nil {
return err return err
} }
addr.Port = p.remoteWgListenPort addr.Port = p.RemoteWgListenPort
err = p.wgInterface.UpdatePeer(p.remoteKey, p.allowedIps, addr) err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
if err != nil { if err != nil {
return err return err

View File

@@ -1,10 +1,15 @@
package proxy package proxy
import ( import (
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"io" "io"
"net" "net"
"time"
) )
const DefaultWgKeepAlive = 25 * time.Second
type Type string type Type string
const ( const (
@@ -14,6 +19,14 @@ const (
TypeNoProxy Type = "NoProxy" TypeNoProxy Type = "NoProxy"
) )
type Config struct {
WgListenAddr string
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
type Proxy interface { type Proxy interface {
io.Closer io.Closer
// Start creates a local remoteConn and starts proxying data from/to remoteConn // Start creates a local remoteConn and starts proxying data from/to remoteConn

View File

@@ -1,12 +1,15 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
import "github.com/google/nftables"
const ( const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding" ipv6Forwarding = "netbird-rt-ipv6-forwarding"

View File

@@ -1,14 +1,17 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {

View File

@@ -1,10 +1,13 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
) )
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@@ -1,9 +1,130 @@
package routemanager package routemanager
import "github.com/netbirdio/netbird/route" import (
"context"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
Stop() Stop()
} }
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRouter: newServerRouter(ctx, wgInterface),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.cleanUp()
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}

View File

@@ -1,31 +0,0 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// DefaultManager dummy router manager for Android
type DefaultManager struct {
ctx context.Context
serverRouter *serverRouter
wgInterface *iface.WGIface
}
// NewManager returns a new dummy route manager what doing nothing
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
return &DefaultManager{}
}
// UpdateRoutes ...
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
return nil
}
// Stop ...
func (m *DefaultManager) Stop() {
}

View File

@@ -1,186 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRoutes map[string]*route.Route
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRoutes: make(map[string]*route.Route),
serverRouter: &serverRouter{
routes: make(map[string]*route.Route),
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
firewall: NewFirewall(ctx),
},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.firewall.CleanRoutingRules()
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.serverRouter.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.serverRoutes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
continue
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.serverRoutes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.serverRoutes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.serverRoutes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.serverRoutes[id] = newRoute
}
if len(m.serverRoutes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.updateServerRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}

View File

@@ -392,11 +392,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@@ -419,7 +420,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes { if testCase.shouldCheckServerRoutes {
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match") require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
} }
}) })
} }

View File

@@ -1,16 +1,19 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
) )
import "github.com/google/nftables"
const ( const (
nftablesTable = "netbird-rt" nftablesTable = "netbird-rt"

View File

@@ -1,12 +1,15 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"testing"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
) )
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@@ -0,0 +1,24 @@
package routemanager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}

View File

@@ -1,67 +0,0 @@
package routemanager
import (
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"net/netip"
"sync"
)
type serverRouter struct {
routes map[string]*route.Route
// best effort to keep net forward configuration as it was
netForwardHistoryEnabled bool
mux sync.Mutex
firewall firewallManager
}
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.serverRouter.routes, route.ID)
return nil
}
}
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.serverRouter.routes[route.ID] = route
return nil
}
}

View File

@@ -0,0 +1,21 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{}
}
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (r *serverRouter) cleanUp() {}

View File

@@ -0,0 +1,120 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewallManager
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: NewFirewall(ctx),
wgInterface: wgInterface,
}
}
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.routes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.routes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.routes[id] = newRoute
}
if len(m.routes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.routes, route.ID)
return nil
}
}
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.routes[route.ID] = route
return nil
}
}
func (m *serverRouter) cleanUp() {
m.firewall.CleanRoutingRules()
}

View File

@@ -0,0 +1,13 @@
package routemanager
import (
"net/netip"
)
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return nil
}

View File

@@ -1,10 +1,13 @@
//go:build !android
package routemanager package routemanager
import ( import (
"github.com/vishvananda/netlink"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"github.com/vishvananda/netlink"
) )
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
@@ -62,12 +65,3 @@ func enableIPForwarding() error {
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
return err return err
} }
func isNetForwardHistoryEnabled() bool {
out, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
// todo
panic(err)
}
return string(out) == "1"
}

View File

@@ -1,11 +1,14 @@
//go:build !android
package routemanager package routemanager
import ( import (
"fmt" "fmt"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"net" "net"
"net/netip" "net/netip"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
) )
var errRouteNotFound = fmt.Errorf("route not found") var errRouteNotFound = fmt.Errorf("route not found")

View File

@@ -37,7 +37,7 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()

View File

@@ -4,10 +4,11 @@
package routemanager package routemanager
import ( import (
log "github.com/sirupsen/logrus"
"net/netip" "net/netip"
"os/exec" "os/exec"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
) )
func addToRouteTable(prefix netip.Prefix, addr string) error { func addToRouteTable(prefix netip.Prefix, addr string) error {
@@ -34,8 +35,3 @@ func enableIPForwarding() error {
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil
} }
func isNetForwardHistoryEnabled() bool {
log.Infof("check netforward history is not implemented on %s", runtime.GOOS)
return false
}

View File

@@ -78,6 +78,9 @@ func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transp
addrs := strings.Trim(fields[1], " \n") addrs := strings.Trim(fields[1], " \n")
foundAddress := false foundAddress := false
for _, addr := range strings.Split(addrs, " ") { for _, addr := range strings.Split(addrs, " ") {
if strings.Contains(addr, "%") {
continue
}
ip, ipNet, err := net.ParseCIDR(addr) ip, ipNet, err := net.ParseCIDR(addr)
if err != nil { if err != nil {
log.Warnf("%s", err) log.Warnf("%s", err)

View File

@@ -3,6 +3,8 @@ package stdnet
import ( import (
"fmt" "fmt"
"testing" "testing"
log "github.com/sirupsen/logrus"
) )
func Test_parseInterfacesString(t *testing.T) { func Test_parseInterfacesString(t *testing.T) {
@@ -20,6 +22,7 @@ func Test_parseInterfacesString(t *testing.T) {
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"}, {"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"}, {"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"}, {"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
{"rmnet_data2", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2%rmnet2/64"},
} }
var exampleString string var exampleString string
@@ -35,13 +38,13 @@ func Test_parseInterfacesString(t *testing.T) {
d.multicast, d.multicast,
d.addr) d.addr)
} }
d := mobileIFaceDiscover{} d := mobileIFaceDiscover{}
nets := d.parseInterfacesString(exampleString) nets := d.parseInterfacesString(exampleString)
if len(nets) == 0 { if len(nets) == 0 {
t.Fatalf("failed to parse interfaces") t.Fatalf("failed to parse interfaces")
} }
log.Printf("%d", len(nets))
for i, net := range nets { for i, net := range nets {
if net.MTU != testData[i].mtu { if net.MTU != testData[i].mtu {
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu) t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
@@ -60,7 +63,7 @@ func Test_parseInterfacesString(t *testing.T) {
if len(addr) == 0 { if len(addr) == 0 {
t.Errorf("invalid address parsing") t.Errorf("invalid address parsing")
} }
log.Printf("%v", addr)
if addr[0].String() != testData[i].addr { if addr[0].String() != testData[i].addr {
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr) t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
} }

2
go.mod
View File

@@ -19,7 +19,7 @@ require (
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.7.0 golang.org/x/crypto v0.7.0
golang.org/x/sys v0.6.0 golang.org/x/sys v0.6.0
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1 golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.52.3 google.golang.org/grpc v1.52.3

2
go.sum
View File

@@ -887,6 +887,8 @@ golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+D
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI= golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU= golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA= golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202 h1:KcD4X7IcoRdQpr9NSQpQpn5S4rUMIQaCdF90FOEj6KY=
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202/go.mod h1:qc3aHNhM1Rc4hW2az896MjLVcxHvLbJ6LZc9MI7RTMY=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ= golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=

View File

@@ -1,81 +1,138 @@
package bind package bind
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip"
"strconv"
"sync"
"syscall"
"github.com/pion/stun" "github.com/pion/stun"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn" "golang.org/x/net/ipv4"
"net" "golang.org/x/net/ipv6"
"net/netip" wgConn "golang.zx2c4.com/wireguard/conn"
"sync"
"syscall"
) )
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library var (
_ wgConn.Bind = (*ICEBind)(nil)
)
// ICEBind implements Bind for all platforms except Windows.
type ICEBind struct { type ICEBind struct {
// below fields, initialized on open mu sync.Mutex // protects following fields
sharedConn net.PacketConn ipv4 *net.UDPConn
udpMux *UniversalUDPMuxDefault ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
batchSize int
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
// below are fields initialized on creation // NetBird related variables
transportNet transport.Net transportNet transport.Net
mu sync.Mutex udpMux *UniversalUDPMuxDefault
worker *worker
} }
// NewICEBind create a new instance of ICEBind with a given transportNet and an interfaceFilter function.
// The interfaceFilter function is used to exclude interfaces from hole punching (the IPs of that interfaces won't be used as connection candidates)
// The transportNet can be nil.
func NewICEBind(transportNet transport.Net) *ICEBind { func NewICEBind(transportNet transport.Net) *ICEBind {
return &ICEBind{ b := &ICEBind{
transportNet: transportNet, batchSize: wgConn.DefaultBatchSize,
mu: sync.Mutex{},
}
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind udpAddrPool: sync.Pool{
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { New: func() any {
b.mu.Lock() return &net.UDPAddr{
defer b.mu.Unlock() IP: make([]byte, 16),
if b.udpMux == nil { }
return nil, fmt.Errorf("ICEBind has not been initialized yet") },
}
return b.udpMux, nil
}
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.sharedConn != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
ipv4Conn, _, err := listenNet("udp4", int(uport))
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.sharedConn = ipv4Conn
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn, Net: b.transportNet})
portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.sharedConn),
}, },
portAddr1.Port(), nil
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, wgConn.DefaultBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, wgConn.DefaultBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
transportNet: transportNet,
}
b.worker = newWorker(b.handlePkgs)
return b
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if supported.
src struct {
netip.Addr
ifidx int32
}
}
var (
_ wgConn.Bind = (*ICEBind)(nil)
_ wgConn.Endpoint = &StdNetEndpoint{}
)
func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
e, err := netip.ParseAddrPort(s)
return asEndpoint(e), err
}
func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0
e.src.Addr = netip.Addr{}
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return e.src.Addr
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return e.src.ifidx
}
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func (e *StdNetEndpoint) SrcToString() string {
return e.src.Addr.String()
} }
func listenNet(network string, port int) (*net.UDPConn, int, error) { func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@@ -89,101 +146,250 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
return conn, uaddr.Port, nil return conn.(*net.UDPConn), uaddr.Port, nil
} }
func parseSTUNMessage(raw []byte) (*stun.Message, error) { func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
msg := &stun.Message{ s.mu.Lock()
Raw: append([]byte{}, raw...), defer s.mu.Unlock()
}
if err := msg.Decode(); err != nil { var err error
return nil, err var tries int
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, wgConn.ErrBindAlreadyOpen
} }
return msg, nil // Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
return nil, 0, err
}
var fns []wgConn.ReceiveFunc
if v4conn != nil {
fns = append(fns, s.receiveIPv4)
s.ipv4 = v4conn
}
if v6conn != nil {
fns = append(fns, s.receiveIPv6)
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
s.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: s.ipv4, Net: s.transportNet})
return fns, uint16(port), nil
} }
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc { func (s *ICEBind) receiveIPv4(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
return func(buff []byte) (int, conn.Endpoint, error) { msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
n, endpoint, err := c.ReadFrom(buff) defer s.ipv4MsgsPool.Put(msgs)
if err != nil { for i := range buffs {
return 0, nil, err (*msgs)[i].Buffers[0] = buffs[i]
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff[:20]) {
// WireGuard traffic
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
}
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
} }
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
s.worker.doWork((*msgs)[:numMsgs], sizes, eps)
return numMsgs, nil
} }
// Close closes the WireGuard socket and UDPMux func (s *ICEBind) receiveIPv6(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
func (b *ICEBind) Close() error { msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
b.mu.Lock() defer s.ipv6MsgsPool.Put(msgs)
defer b.mu.Unlock() for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *ICEBind) BatchSize() int {
return s.batchSize
}
func (s *ICEBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error var err1, err2 error
if b.sharedConn != nil { if s.ipv4 != nil {
c := b.sharedConn err1 = s.ipv4.Close()
b.sharedConn = nil s.ipv4 = nil
err1 = c.Close()
} }
if s.ipv6 != nil {
if b.udpMux != nil { err2 = s.ipv6.Close()
m := b.udpMux s.ipv6 = nil
b.udpMux = nil
err2 = m.Close()
} }
s.blackhole4 = false
s.blackhole6 = false
if err1 != nil { if err1 != nil {
return err1 return err1
} }
return err2 return err2
} }
// SetMark sets the mark for each packet sent through this Bind. func (s *ICEBind) Send(buffs [][]byte, endpoint wgConn.Endpoint) error {
// This mark is passed to the kernel as the socket option SO_MARK. s.mu.Lock()
func (b *ICEBind) SetMark(mark uint32) error { blackhole := s.blackhole4
return nil conn := s.ipv4
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
is6 = true
}
s.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
if is6 {
return s.send6(s.ipv6PC, endpoint, buffs)
} else {
return s.send4(s.ipv4PC, endpoint, buffs)
}
} }
// Send bytes to the remote endpoint (peer) // GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error { func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.mu.Lock()
nend, ok := endpoint.(conn.StdNetEndpoint) defer s.mu.Unlock()
if !ok { if s.udpMux == nil {
return conn.ErrWrongEndpointType return nil, fmt.Errorf("ICEBind has not been initialized yet")
} }
addrPort := netip.AddrPort(nend)
_, err := b.sharedConn.WriteTo(buff, &net.UDPAddr{ return s.udpMux, nil
IP: addrPort.Addr().AsSlice(), }
Port: int(addrPort.Port()),
Zone: addrPort.Addr().Zone(), func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
}) ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err return err
} }
// ParseEndpoint creates a new endpoint from a string. func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) { ua := s.udpAddrPool.Get().(*net.UDPAddr)
e, err := netip.ParseAddrPort(s) as16 := ep.DstIP().As16()
return asEndpoint(e), err copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for _, buffer := range buffers {
if !stun.IsMessage(buffer) {
continue
}
msg, err := parseSTUNMessage(buffer[:n])
if err != nil {
buffer = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle packet")
}
buffer = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) handlePkgs(msg *ipv4.Message) (int, *StdNetEndpoint) {
// todo: handle err
size := 0
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if !ok {
size = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
return size, ep
} }
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
@@ -191,18 +397,29 @@ func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
// but Endpoints are immutable, so we can re-use them. // but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{ var endpointPool = sync.Pool{
New: func() any { New: func() any {
return make(map[netip.AddrPort]conn.Endpoint) return make(map[netip.AddrPort]*StdNetEndpoint)
}, },
} }
// asEndpoint returns an Endpoint containing ap. // asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) conn.Endpoint { func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint) m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m) defer endpointPool.Put(m)
e, ok := m[ap] e, ok := m[ap]
if !ok { if !ok {
e = conn.Endpoint(conn.StdNetEndpoint(ap)) e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e m[ap] = e
} }
return e return e
} }
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}

36
iface/bind/controlfns.go Normal file
View File

@@ -0,0 +1,36 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"net"
"syscall"
)
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
}

View File

@@ -0,0 +1,41 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"fmt"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
case "udp6":
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

View File

@@ -0,0 +1,28 @@
//go:build !windows && !linux && !js
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
var err error
if network == "udp6" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
}
return err
},
)
}

View File

@@ -0,0 +1,12 @@
//go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
func (s *ICEBind) SetMark(mark uint32) error {
return nil
}

65
iface/bind/mark_unix.go Normal file
View File

@@ -0,0 +1,65 @@
//go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"runtime"
"golang.org/x/sys/unix"
)
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (s *ICEBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
if s.ipv4 != nil {
fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
if s.ipv6 != nil {
fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,28 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
// use alternatively named flags and need ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *wgConn.StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *wgConn.StdNetEndpoint) {
}
// srcControlSize returns the recommended buffer size for pooling sticky control
// data.
const srcControlSize = 0

111
iface/bind/sticky_linux.go Normal file
View File

@@ -0,0 +1,111 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
ep.src.ifidx = info.Ifindex
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
ep.src.Addr = netip.AddrFrom16(info.Addr)
ep.src.ifidx = int32(info.Ifindex)
return
}
}
}
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
// panics if buf is of insufficient size.
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
size := int(unsafe.Sizeof(t))
if len(buf) < size {
panic("pktInfoFromBuf: buffer too small")
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
return t
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = (*control)[:cap(*control)]
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
*control = (*control)[:0]
return
}
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
*control = (*control)[:0]
return
}
if len(*control) < srcControlSize {
*control = (*control)[:0]
return
}
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
if ep.SrcIP().Is4() {
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = ep.src.ifidx
if ep.SrcIP().IsValid() {
info.Spec_dst = ep.SrcIP().As4()
}
} else {
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = uint32(ep.src.ifidx)
if ep.SrcIP().IsValid() {
info.Addr = ep.SrcIP().As16()
}
}
*control = (*control)[:hdr.Len]
}
var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)

View File

@@ -0,0 +1,207 @@
//go:build linux
// +build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"context"
"net"
"net/netip"
"runtime"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
ep.src.Addr = netip.MustParseAddr("127.0.0.1")
ep.src.ifidx = 5
control := make([]byte, srcControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IP {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IP_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
t.Errorf("unexpected address: %v", info.Spec_dst)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("IPv6", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
ep.src.Addr = netip.MustParseAddr("::1")
ep.src.ifidx = 5
control := make([]byte, srcControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IPV6 {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IPV6_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Addr != ep.SrcIP().As16() {
t.Errorf("unexpected address: %v", info.Addr)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
hdr.Len = 3
setSrcControl(&control, &StdNetEndpoint{})
if len(control) != 0 {
t.Errorf("unexpected control: %v", control)
}
})
}
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.src.Addr)
}
if ep.src.ifidx != 5 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.src.ifidx != 5 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
control := make([]byte, srcControlSize)
ep := &StdNetEndpoint{}
ep.src.Addr = netip.MustParseAddr("::1")
ep.src.ifidx = 5
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
t.Errorf("unexpected address: %v", ep.src.Addr)
}
if ep.src.ifidx != 0 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
}
func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IP_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IPV6_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
}

View File

@@ -2,15 +2,16 @@ package bind
import ( import (
"fmt" "fmt"
"github.com/pion/ice/v2"
"github.com/pion/stun"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"io" "io"
"net" "net"
"strings" "strings"
"sync" "sync"
"github.com/pion/ice/v2"
"github.com/pion/stun"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
) )

View File

@@ -6,10 +6,11 @@ package bind
import ( import (
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"net" "net"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun" "github.com/pion/stun"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"

56
iface/bind/worker.go Normal file
View File

@@ -0,0 +1,56 @@
package bind
import (
"runtime"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// todo: add close function
type worker struct {
jobOffer chan int
numOfWorker int
jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)
messages []ipv4.Message
sizes []int
eps []wgConn.Endpoint
}
func newWorker(jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)) *worker {
w := &worker{
jobOffer: make(chan int),
numOfWorker: runtime.NumCPU(),
jobFn: jobFn,
}
w.populateWorkers()
return w
}
func (w *worker) doWork(messages []ipv4.Message, sizes []int, eps []wgConn.Endpoint) {
w.messages = messages
w.sizes = sizes
w.eps = eps
for i := 0; i < len(messages); i++ {
w.jobOffer <- i
}
}
func (w *worker) populateWorkers() {
for i := 0; i < w.numOfWorker; i++ {
go w.loop()
}
}
func (w *worker) loop() {
for {
select {
case msgPos := <-w.jobOffer:
w.sizes[msgPos], w.eps[msgPos] = w.jobFn(&w.messages[msgPos])
}
}
}

View File

@@ -8,13 +8,12 @@ import (
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
const ( const (
DefaultMTU = 1280 DefaultMTU = 1280
DefaultWgPort = 51820 DefaultWgPort = 51820
defaultWgKeepAlive = 25 * time.Second
) )
// WGIface represents a interface instance // WGIface represents a interface instance
@@ -78,12 +77,12 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional // Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, endpoint *net.UDPAddr) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint) log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.updatePeer(peerKey, allowedIps, endpoint) return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface

View File

@@ -7,7 +7,7 @@ import (
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, preSharedKey *wgtypes.Key, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
@@ -17,10 +17,10 @@ func NewWGIFace(iFaceName string, address string, mtu int, preSharedKey *wgtypes
return wgIFace, err return wgIFace, err
} }
tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet) tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet)
wgIFace.tun = tun wgIFace.tun = tun
wgIFace.configurer = newWGConfigurer(tun, preSharedKey) wgIFace.configurer = newWGConfigurer(tun)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded() wgIFace.userspaceBind = !WireGuardModuleIsLoaded()

View File

@@ -3,14 +3,13 @@
package iface package iface
import ( import (
"github.com/pion/transport/v2"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"sync" "sync"
"github.com/pion/transport/v2"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, preSharedKey *wgtypes.Key, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
@@ -22,7 +21,7 @@ func NewWGIFace(iFaceName string, address string, mtu int, preSharedKey *wgtypes
wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet) wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIFace.configurer = newWGConfigurer(iFaceName, preSharedKey) wgIFace.configurer = newWGConfigurer(iFaceName)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded() wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -2,14 +2,15 @@ package iface
import ( import (
"fmt" "fmt"
"net"
"testing"
"time"
"github.com/pion/transport/v2/stdnet" "github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"testing"
"time"
) )
// keep darwin compability // keep darwin compability
@@ -38,7 +39,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -102,7 +103,7 @@ func Test_CreateInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -135,7 +136,7 @@ func Test_Close(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -167,7 +168,7 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -218,7 +219,7 @@ func Test_UpdatePeer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -281,7 +282,7 @@ func Test_RemovePeer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -334,7 +335,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet) iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -355,7 +356,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet) iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -7,9 +7,6 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"io" "io"
"io/fs" "io/fs"
"math" "math"
@@ -17,6 +14,10 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
) )
// Holds logic to check existence of kernel modules used by wireguard interfaces // Holds logic to check existence of kernel modules used by wireguard interfaces

View File

@@ -3,13 +3,14 @@ package iface
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
) )
func TestGetModuleDependencies(t *testing.T) { func TestGetModuleDependencies(t *testing.T) {

View File

@@ -2,6 +2,6 @@ package iface
// TunAdapter is an interface for create tun device from externel service // TunAdapter is an interface for create tun device from externel service
type TunAdapter interface { type TunAdapter interface {
ConfigureInterface(address string, mtu int) (int, error) ConfigureInterface(address string, mtu int, routes string) (int, error)
UpdateAddr(address string) error UpdateAddr(address string) error
} }

View File

@@ -1,33 +1,34 @@
package iface package iface
import ( import (
"github.com/netbirdio/netbird/iface/bind" "strings"
"github.com/pion/transport/v2"
"net"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
) )
type tunDevice struct { type tunDevice struct {
address WGAddress address WGAddress
mtu int mtu int
routes []string
tunAdapter TunAdapter tunAdapter TunAdapter
fd int fd int
name string name string
device *device.Device device *device.Device
uapi net.Listener
iceBind *bind.ICEBind iceBind *bind.ICEBind
} }
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
return &tunDevice{ return &tunDevice{
address: address, address: address,
mtu: mtu, mtu: mtu,
routes: routes,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet),
} }
@@ -35,7 +36,8 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
func (t *tunDevice) Create() error { func (t *tunDevice) Create() error {
var err error var err error
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu) routesString := t.routesToString()
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return err return err
@@ -52,32 +54,8 @@ func (t *tunDevice) Create() error {
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device.DisableSomeRoamingForBrokenMobileSemantics() t.device.DisableSomeRoamingForBrokenMobileSemantics()
log.Debugf("create uapi")
tunSock, err := ipc.UAPIOpen(name)
if err != nil {
return err
}
t.uapi, err = ipc.UAPIListen(name, tunSock)
if err != nil {
tunSock.Close()
unix.Close(t.fd)
return err
}
go func() {
for {
uapiConn, err := t.uapi.Accept()
if err != nil {
return
}
go t.device.IpcHandle(uapiConn)
}
}()
err = t.device.Up() err = t.device.Up()
if err != nil { if err != nil {
tunSock.Close()
t.device.Close() t.device.Close()
return err return err
} }
@@ -103,13 +81,13 @@ func (t *tunDevice) UpdateAddr(addr WGAddress) error {
} }
func (t *tunDevice) Close() (err error) { func (t *tunDevice) Close() (err error) {
if t.uapi != nil {
err = t.uapi.Close()
}
if t.device != nil { if t.device != nil {
t.device.Close() t.device.Close()
} }
return return
} }
func (t *tunDevice) routesToString() string {
return strings.Join(t.routes, ";")
}

View File

@@ -3,14 +3,16 @@
package iface package iface
import ( import (
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2"
"net" "net"
"os" "os"
"github.com/pion/transport/v2"
"golang.zx2c4.com/wireguard/ipc"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@@ -20,6 +22,8 @@ type tunDevice struct {
mtu int mtu int
netInterface NetInterface netInterface NetInterface
iceBind *bind.ICEBind iceBind *bind.ICEBind
uapi net.Listener
close chan struct{}
} }
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice { func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
@@ -28,6 +32,7 @@ func newTunDevice(name string, address WGAddress, mtu int, transportNet transpor
address: address, address: address,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
} }
} }
@@ -45,23 +50,38 @@ func (c *tunDevice) DeviceName() string {
} }
func (c *tunDevice) Close() error { func (c *tunDevice) Close() error {
if c.netInterface == nil {
return nil select {
case c.close <- struct{}{}:
default:
} }
err := c.netInterface.Close()
if err != nil { var err1, err2, err3 error
return err if c.netInterface != nil {
err1 = c.netInterface.Close()
}
if c.uapi != nil {
err2 = c.uapi.Close()
} }
sockPath := "/var/run/wireguard/" + c.name + ".sock" sockPath := "/var/run/wireguard/" + c.name + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil { if _, statErr := os.Stat(sockPath); statErr == nil {
statErr = os.Remove(sockPath) statErr = os.Remove(sockPath)
if statErr != nil { if statErr != nil {
return statErr err3 = statErr
} }
} }
return nil if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
} }
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation // createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
@@ -72,26 +92,36 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
} }
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDevice.Up() err = tunDev.Up()
if err != nil { if err != nil {
return tunIface, err _ = tunIface.Close()
return nil, err
} }
// todo: after this line in case of error close the tunSock c.uapi, err = c.getUAPI(c.name)
uapi, err := c.getUAPI(c.name)
if err != nil { if err != nil {
return tunIface, err _ = tunIface.Close()
return nil, err
} }
go func() { go func() {
for { for {
uapiConn, uapiErr := uapi.Accept() select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil { if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr) log.Traceln("uapi Accept failed with error: ", uapiErr)
continue continue
} }
go tunDevice.IpcHandle(uapiConn) go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
} }
}() }()

View File

@@ -2,27 +2,36 @@ package iface
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/iface/bind" "net"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"net"
"github.com/netbirdio/netbird/iface/bind"
) )
type tunDevice struct { type tunDevice struct {
name string name string
address WGAddress address WGAddress
netInterface NetInterface netInterface NetInterface
uapi net.Listener
iceBind *bind.ICEBind iceBind *bind.ICEBind
mtu int mtu int
uapi net.Listener
close chan struct{}
} }
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice { func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{name: name, address: address, iceBind: bind.NewICEBind(transportNet), mtu: mtu} return &tunDevice{
name: name,
address: address,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
}
} }
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
@@ -35,34 +44,44 @@ func (c *tunDevice) Create() error {
return c.assignAddr() return c.assignAddr()
} }
// createWithUserspace Creates a new WireGuard interface, using wireguard-go userspace implementation // createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (c *tunDevice) createWithUserspace() (NetInterface, error) { func (c *tunDevice) createWithUserspace() (NetInterface, error) {
tunIface, err := tun.CreateTUN(c.name, c.mtu) tunIface, err := tun.CreateTUN(c.name, c.mtu)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDevice.Up()
if err != nil {
return tunIface, err
}
uapi, err := c.getUAPI(c.name) // We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDev.Up()
if err != nil { if err != nil {
_ = tunIface.Close()
return nil, err
}
c.uapi, err = c.getUAPI(c.name)
if err != nil {
_ = tunIface.Close()
return nil, err return nil, err
} }
c.uapi = uapi
go func() { go func() {
for { for {
uapiConn, uapiErr := uapi.Accept() select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil { if uapiErr != nil {
log.Traceln("uapi accept failed with error: ", uapiErr) log.Traceln("uapi Accept failed with error: ", uapiErr)
continue continue
} }
go tunDevice.IpcHandle(uapiConn) go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
} }
}() }()
@@ -84,6 +103,11 @@ func (c *tunDevice) DeviceName() string {
} }
func (c *tunDevice) Close() error { func (c *tunDevice) Close() error {
select {
case c.close <- struct{}{}:
default:
}
var err1, err2 error var err1, err2 error
if c.netInterface != nil { if c.netInterface != nil {
err1 = c.netInterface.Close() err1 = c.netInterface.Close()

View File

@@ -3,6 +3,7 @@ package iface
import ( import (
"errors" "errors"
"net" "net"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -14,14 +15,12 @@ var (
) )
type wGConfigurer struct { type wGConfigurer struct {
tunDevice *tunDevice tunDevice *tunDevice
preSharedKey *wgtypes.Key
} }
func newWGConfigurer(tunDevice *tunDevice, preSharedKey *wgtypes.Key) wGConfigurer { func newWGConfigurer(tunDevice *tunDevice) wGConfigurer {
return wGConfigurer{ return wGConfigurer{
tunDevice: tunDevice, tunDevice: tunDevice,
preSharedKey: preSharedKey,
} }
} }
@@ -42,7 +41,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) return c.tunDevice.Device().IpcSet(toWgUserspaceString(config))
} }
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, endpoint *net.UDPAddr) error { func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
//parse allowed ips //parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps) _, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil { if err != nil {
@@ -53,13 +52,12 @@ func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, endpoint *n
if err != nil { if err != nil {
return err return err
} }
keepalive := defaultWgKeepAlive
peer := wgtypes.PeerConfig{ peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true, ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: c.preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,
} }

View File

@@ -5,6 +5,7 @@ package iface
import ( import (
"fmt" "fmt"
"net" "net"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
@@ -12,14 +13,12 @@ import (
) )
type wGConfigurer struct { type wGConfigurer struct {
deviceName string deviceName string
preSharedKey *wgtypes.Key
} }
func newWGConfigurer(deviceName string, preSharedKey *wgtypes.Key) wGConfigurer { func newWGConfigurer(deviceName string) wGConfigurer {
return wGConfigurer{ return wGConfigurer{
deviceName: deviceName, deviceName: deviceName,
preSharedKey: preSharedKey,
} }
} }
@@ -44,7 +43,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
return nil return nil
} }
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, endpoint *net.UDPAddr) error { func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
//parse allowed ips //parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps) _, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil { if err != nil {
@@ -55,13 +54,12 @@ func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, endpoint *n
if err != nil { if err != nil {
return err return err
} }
keepalive := defaultWgKeepAlive
peer := wgtypes.PeerConfig{ peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true, ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepalive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: c.preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,
} }

View File

@@ -172,6 +172,49 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil return nil
} }
// GetRoutes return with the routes
func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
}
defer func() {
_ = stream.CloseSend()
}()
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return nil, err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return nil, err
}
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return nil, err
}
if decryptedResp.GetNetworkMap() == nil {
return nil, fmt.Errorf("invalid msg, required network map")
}
return decryptedResp.GetNetworkMap().GetRoutes(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{} req := &proto.SyncRequest{}

View File

@@ -283,7 +283,7 @@ func (a *Account) GetGroup(groupID string) *Group {
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise // GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
aclPeers, _ := a.getPeersByPolicy(peerID) aclPeers := a.getPeersByACL(peerID)
// exclude expired peers // exclude expired peers
var peersToConnect []*Peer var peersToConnect []*Peer
var expiredPeers []*Peer var expiredPeers []*Peer

View File

@@ -286,9 +286,7 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
} }
if accountCopy.Rules == nil { accountCopy.Rules = make(map[string]*Rule)
accountCopy.Rules = make(map[string]*Rule)
}
for _, policy := range accountCopy.Policies { for _, policy := range accountCopy.Policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
accountCopy.Rules[rule.ID] = rule.ToRule() accountCopy.Rules[rule.ID] = rule.ToRule()

View File

@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
// TODO: use firewall rules // TODO: use firewall rules
aclPeers, _ := account.getPeersByPolicy(peer.ID) aclPeers := account.getPeersByACL(peer.ID)
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -816,7 +816,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*Pee
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.getPeersByPolicy(p.ID) aclPeers := account.getPeersByACL(p.ID)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID { if aclPeer.ID == peerID {
return peer, nil return peer, nil
@@ -833,6 +833,98 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) *Peer {
return peer return peer
} }
// GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{})
for s, group := range a.Groups {
for _, peer := range group.Peers {
if peerID == peer {
peerGroups[s] = struct{}{}
break
}
}
}
// Second, find all rules that have discovered source and destination groups
srcRulesMap := make(map[string]*Rule)
dstRulesMap := make(map[string]*Rule)
for _, rule := range a.Rules {
for _, g := range rule.Source {
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
srcRules = append(srcRules, rule)
srcRulesMap[rule.ID] = rule
}
}
for _, g := range rule.Destination {
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
dstRules = append(dstRules, rule)
dstRulesMap[rule.ID] = rule
}
}
}
return srcRules, dstRules
}
// getPeersByACL returns all peers that given peer has access to.
func (a *Account) getPeersByACL(peerID string) []*Peer {
var peers []*Peer
srcRules, dstRules := a.GetPeerRules(peerID)
groups := map[string]*Group{}
for _, r := range srcRules {
if r.Disabled {
continue
}
if r.Flow == TrafficFlowBidirect {
for _, gid := range r.Destination {
if group, ok := a.Groups[gid]; ok {
groups[gid] = group
}
}
}
}
for _, r := range dstRules {
if r.Disabled {
continue
}
if r.Flow == TrafficFlowBidirect {
for _, gid := range r.Source {
if group, ok := a.Groups[gid]; ok {
groups[gid] = group
}
}
}
}
peersSet := make(map[string]struct{})
for _, g := range groups {
for _, pid := range g.Peers {
peer, ok := a.Peers[pid]
if !ok {
log.Warnf(
"peer %s found in group %s but doesn't belong to account %s",
pid,
g.ID,
a.Id,
)
continue
}
// exclude original peer
if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok {
peersSet[peer.ID] = struct{}{}
peers = append(peers, peer.Copy())
}
}
}
return peers
}
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {

View File

@@ -136,6 +136,8 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
} }
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
// TODO: disable until we start use policy again
t.Skip()
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)