mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Compare commits
10 Commits
v0.16.0
...
wg_bind_pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
940367e1c6 | ||
|
|
c80cb22cc0 | ||
|
|
d0e51dfc11 | ||
|
|
7e11842a5e | ||
|
|
1bf3e6feec | ||
|
|
c392fe44e4 | ||
|
|
fef24b4a4d | ||
|
|
bb1efe68aa | ||
|
|
bb147c2a7c | ||
|
|
4616bc5258 |
@@ -144,13 +144,19 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover)
|
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, statusRecorder)
|
md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return wrapErr(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
@@ -194,13 +200,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||||
|
|
||||||
engineConf := &EngineConfig{
|
engineConf := &EngineConfig{
|
||||||
WgIfaceName: config.WgIface,
|
WgIfaceName: config.WgIface,
|
||||||
WgAddr: peerConfig.Address,
|
WgAddr: peerConfig.Address,
|
||||||
TunAdapter: tunAdapter,
|
|
||||||
IFaceDiscover: iFaceDiscover,
|
|
||||||
IFaceBlackList: config.IFaceBlackList,
|
IFaceBlackList: config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
|
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@@ -47,10 +46,6 @@ var ErrResetConnection = fmt.Errorf("reset connection")
|
|||||||
type EngineConfig struct {
|
type EngineConfig struct {
|
||||||
WgPort int
|
WgPort int
|
||||||
WgIfaceName string
|
WgIfaceName string
|
||||||
// TunAdapter is option. It is necessary for mobile version.
|
|
||||||
TunAdapter iface.TunAdapter
|
|
||||||
|
|
||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
|
||||||
|
|
||||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||||
WgAddr string
|
WgAddr string
|
||||||
@@ -91,6 +86,8 @@ type Engine struct {
|
|||||||
syncMsgMux *sync.Mutex
|
syncMsgMux *sync.Mutex
|
||||||
|
|
||||||
config *EngineConfig
|
config *EngineConfig
|
||||||
|
mobileDep MobileDependency
|
||||||
|
|
||||||
// STUNs is a list of STUN servers used by ICE
|
// STUNs is a list of STUN servers used by ICE
|
||||||
STUNs []*ice.URL
|
STUNs []*ice.URL
|
||||||
// TURNs is a list of STUN servers used by ICE
|
// TURNs is a list of STUN servers used by ICE
|
||||||
@@ -130,7 +127,7 @@ type Peer struct {
|
|||||||
func NewEngine(
|
func NewEngine(
|
||||||
ctx context.Context, cancel context.CancelFunc,
|
ctx context.Context, cancel context.CancelFunc,
|
||||||
signalClient signal.Client, mgmClient mgm.Client,
|
signalClient signal.Client, mgmClient mgm.Client,
|
||||||
config *EngineConfig, statusRecorder *peer.Status,
|
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -140,6 +137,7 @@ func NewEngine(
|
|||||||
peerConns: make(map[string]*peer.Conn),
|
peerConns: make(map[string]*peer.Conn),
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
|
mobileDep: mobileDep,
|
||||||
STUNs: []*ice.URL{},
|
STUNs: []*ice.URL{},
|
||||||
TURNs: []*ice.URL{},
|
TURNs: []*ice.URL{},
|
||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
@@ -181,7 +179,7 @@ func (e *Engine) Start() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter, transportNet)
|
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
|
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
|
||||||
return err
|
return err
|
||||||
@@ -834,7 +832,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
|
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package internal
|
|||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNetWithDiscover(e.config.IFaceDiscover, e.config.IFaceBlackList)
|
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@@ -208,12 +208,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet)
|
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -404,7 +404,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@@ -562,12 +562,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
input := struct {
|
input := struct {
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
@@ -731,12 +731,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, peer.NewRecorder("https://mgm"))
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
@@ -1000,7 +1000,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
WgPort: wgPort,
|
WgPort: wgPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil
|
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startSignal() (*grpc.Server, string, error) {
|
func startSignal() (*grpc.Server, string, error) {
|
||||||
|
|||||||
13
client/internal/mobile_dependency.go
Normal file
13
client/internal/mobile_dependency.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MobileDependency collect all dependencies for mobile platform
|
||||||
|
type MobileDependency struct {
|
||||||
|
TunAdapter iface.TunAdapter
|
||||||
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
|
Routes []string
|
||||||
|
}
|
||||||
29
client/internal/mobile_dependency_android.go
Normal file
29
client/internal/mobile_dependency_android.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||||
|
md := MobileDependency{
|
||||||
|
TunAdapter: tunAdapter,
|
||||||
|
IFaceDiscover: ifaceDiscover,
|
||||||
|
}
|
||||||
|
err := md.readMap(mgmClient)
|
||||||
|
return md, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
|
||||||
|
routes, err := mgmClient.GetRoutes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Routes = make([]string, len(routes))
|
||||||
|
for i, r := range routes {
|
||||||
|
d.Routes[i] = r.GetNetwork()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
13
client/internal/mobile_dependency_nonandroid.go
Normal file
13
client/internal/mobile_dependency_nonandroid.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||||
|
return MobileDependency{}, nil
|
||||||
|
}
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
import "github.com/google/nftables"
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
||||||
|
|||||||
@@ -1,14 +1,17 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
func isIptablesSupported() bool {
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||||
|
|||||||
@@ -1,9 +1,130 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/route"
|
import (
|
||||||
|
"context"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
Stop()
|
Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultManager is the default instance of a route manager
|
||||||
|
type DefaultManager struct {
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
mux sync.Mutex
|
||||||
|
clientNetworks map[string]*clientNetwork
|
||||||
|
serverRouter *serverRouter
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
pubKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager returns a new route manager
|
||||||
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||||
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
|
return &DefaultManager{
|
||||||
|
ctx: mCTX,
|
||||||
|
stop: cancel,
|
||||||
|
clientNetworks: make(map[string]*clientNetwork),
|
||||||
|
serverRouter: newServerRouter(ctx, wgInterface),
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
pubKey: pubKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
|
func (m *DefaultManager) Stop() {
|
||||||
|
m.stop()
|
||||||
|
m.serverRouter.cleanUp()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||||
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
log.Infof("not updating routes as context is closed")
|
||||||
|
return m.ctx.Err()
|
||||||
|
default:
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||||
|
newServerRoutesMap := make(map[string]*route.Route)
|
||||||
|
ownNetworkIDs := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, newRoute := range newRoutes {
|
||||||
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
|
if newRoute.Peer == m.pubKey {
|
||||||
|
ownNetworkIDs[networkID] = true
|
||||||
|
// only linux is supported for now
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newServerRoutesMap[newRoute.ID] = newRoute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newRoute := range newRoutes {
|
||||||
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
|
if !ownNetworkIDs[networkID] {
|
||||||
|
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||||
|
// we skip this route management
|
||||||
|
if newRoute.Network.Bits() < 7 {
|
||||||
|
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||||
|
version.NetbirdVersion(), newRoute.Network)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||||
|
|
||||||
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||||
|
// removing routes that do not exist as per the update from the Management service.
|
||||||
|
for id, client := range m.clientNetworks {
|
||||||
|
_, found := networks[id]
|
||||||
|
if !found {
|
||||||
|
log.Debugf("stopping client network watcher, %s", id)
|
||||||
|
client.stop()
|
||||||
|
delete(m.clientNetworks, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, routes := range networks {
|
||||||
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
|
if !found {
|
||||||
|
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||||
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
|
}
|
||||||
|
update := routesUpdate{
|
||||||
|
updateSerial: updateSerial,
|
||||||
|
routes: routes,
|
||||||
|
}
|
||||||
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DefaultManager dummy router manager for Android
|
|
||||||
type DefaultManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
serverRouter *serverRouter
|
|
||||||
wgInterface *iface.WGIface
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager returns a new dummy route manager what doing nothing
|
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
|
||||||
return &DefaultManager{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateRoutes ...
|
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop ...
|
|
||||||
func (m *DefaultManager) Stop() {
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,186 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DefaultManager is the default instance of a route manager
|
|
||||||
type DefaultManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
clientNetworks map[string]*clientNetwork
|
|
||||||
serverRoutes map[string]*route.Route
|
|
||||||
serverRouter *serverRouter
|
|
||||||
statusRecorder *peer.Status
|
|
||||||
wgInterface *iface.WGIface
|
|
||||||
pubKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager returns a new route manager
|
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
|
||||||
return &DefaultManager{
|
|
||||||
ctx: mCTX,
|
|
||||||
stop: cancel,
|
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
|
||||||
serverRoutes: make(map[string]*route.Route),
|
|
||||||
serverRouter: &serverRouter{
|
|
||||||
routes: make(map[string]*route.Route),
|
|
||||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
|
||||||
firewall: NewFirewall(ctx),
|
|
||||||
},
|
|
||||||
statusRecorder: statusRecorder,
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
pubKey: pubKey,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the manager watchers and clean firewall rules
|
|
||||||
func (m *DefaultManager) Stop() {
|
|
||||||
m.stop()
|
|
||||||
m.serverRouter.firewall.CleanRoutingRules()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
|
||||||
// removing routes that do not exist as per the update from the Management service.
|
|
||||||
for id, client := range m.clientNetworks {
|
|
||||||
_, found := networks[id]
|
|
||||||
if !found {
|
|
||||||
log.Debugf("stopping client network watcher, %s", id)
|
|
||||||
client.stop()
|
|
||||||
delete(m.clientNetworks, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, routes := range networks {
|
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
|
||||||
if !found {
|
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
|
||||||
}
|
|
||||||
update := routesUpdate{
|
|
||||||
updateSerial: updateSerial,
|
|
||||||
routes: routes,
|
|
||||||
}
|
|
||||||
|
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
|
||||||
serverRoutesToRemove := make([]string, 0)
|
|
||||||
|
|
||||||
if len(routesMap) > 0 {
|
|
||||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for routeID := range m.serverRoutes {
|
|
||||||
update, found := routesMap[routeID]
|
|
||||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
|
||||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, routeID := range serverRoutesToRemove {
|
|
||||||
oldRoute := m.serverRoutes[routeID]
|
|
||||||
err := m.removeFromServerNetwork(oldRoute)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
|
||||||
oldRoute.ID, oldRoute.Network, err)
|
|
||||||
}
|
|
||||||
delete(m.serverRoutes, routeID)
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, newRoute := range routesMap {
|
|
||||||
_, found := m.serverRoutes[id]
|
|
||||||
if found {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := m.addToServerNetwork(newRoute)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m.serverRoutes[id] = newRoute
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.serverRoutes) > 0 {
|
|
||||||
err := enableIPForwarding()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
|
||||||
select {
|
|
||||||
case <-m.ctx.Done():
|
|
||||||
log.Infof("not updating routes as context is closed")
|
|
||||||
return m.ctx.Err()
|
|
||||||
default:
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
|
||||||
newServerRoutesMap := make(map[string]*route.Route)
|
|
||||||
ownNetworkIDs := make(map[string]bool)
|
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
|
||||||
if newRoute.Peer == m.pubKey {
|
|
||||||
ownNetworkIDs[networkID] = true
|
|
||||||
// only linux is supported for now
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newServerRoutesMap[newRoute.ID] = newRoute
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
|
||||||
if !ownNetworkIDs[networkID] {
|
|
||||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
|
||||||
// we skip this route management
|
|
||||||
if newRoute.Network.Bits() < 7 {
|
|
||||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
|
||||||
version.NetbirdVersion(), newRoute.Network)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
|
||||||
|
|
||||||
err := m.updateServerRoutes(newServerRoutesMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -392,11 +392,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet)
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
defer wgInterface.Close()
|
defer wgInterface.Close()
|
||||||
|
|
||||||
@@ -419,7 +420,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||||
|
|
||||||
if testCase.shouldCheckServerRoutes {
|
if testCase.shouldCheckServerRoutes {
|
||||||
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match")
|
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,19 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
import "github.com/google/nftables"
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
nftablesTable = "netbird-rt"
|
nftablesTable = "netbird-rt"
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||||
|
|||||||
24
client/internal/routemanager/router_pair.go
Normal file
24
client/internal/routemanager/router_pair.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type routerPair struct {
|
||||||
|
ID string
|
||||||
|
source string
|
||||||
|
destination string
|
||||||
|
masquerade bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeToRouterPair(source string, route *route.Route) routerPair {
|
||||||
|
parsed := netip.MustParsePrefix(source).Masked()
|
||||||
|
return routerPair{
|
||||||
|
ID: route.ID,
|
||||||
|
source: parsed.String(),
|
||||||
|
destination: route.Network.Masked().String(),
|
||||||
|
masquerade: route.Masquerade,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type serverRouter struct {
|
|
||||||
routes map[string]*route.Route
|
|
||||||
// best effort to keep net forward configuration as it was
|
|
||||||
netForwardHistoryEnabled bool
|
|
||||||
mux sync.Mutex
|
|
||||||
firewall firewallManager
|
|
||||||
}
|
|
||||||
|
|
||||||
type routerPair struct {
|
|
||||||
ID string
|
|
||||||
source string
|
|
||||||
destination string
|
|
||||||
masquerade bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func routeToRouterPair(source string, route *route.Route) routerPair {
|
|
||||||
parsed := netip.MustParsePrefix(source).Masked()
|
|
||||||
return routerPair{
|
|
||||||
ID: route.ID,
|
|
||||||
source: parsed.String(),
|
|
||||||
destination: route.Network.Masked().String(),
|
|
||||||
masquerade: route.Masquerade,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
|
|
||||||
select {
|
|
||||||
case <-m.ctx.Done():
|
|
||||||
log.Infof("not removing from server network because context is done")
|
|
||||||
return m.ctx.Err()
|
|
||||||
default:
|
|
||||||
m.serverRouter.mux.Lock()
|
|
||||||
defer m.serverRouter.mux.Unlock()
|
|
||||||
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
delete(m.serverRouter.routes, route.ID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
|
|
||||||
select {
|
|
||||||
case <-m.ctx.Done():
|
|
||||||
log.Infof("not adding to server network because context is done")
|
|
||||||
return m.ctx.Err()
|
|
||||||
default:
|
|
||||||
m.serverRouter.mux.Lock()
|
|
||||||
defer m.serverRouter.mux.Unlock()
|
|
||||||
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.serverRouter.routes[route.ID] = route
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
21
client/internal/routemanager/server_android.go
Normal file
21
client/internal/routemanager/server_android.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serverRouter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||||
|
return &serverRouter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *serverRouter) cleanUp() {}
|
||||||
120
client/internal/routemanager/server_nonandroid.go
Normal file
120
client/internal/routemanager/server_nonandroid.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serverRouter struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
routes map[string]*route.Route
|
||||||
|
firewall firewallManager
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||||
|
return &serverRouter{
|
||||||
|
ctx: ctx,
|
||||||
|
routes: make(map[string]*route.Route),
|
||||||
|
firewall: NewFirewall(ctx),
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||||
|
serverRoutesToRemove := make([]string, 0)
|
||||||
|
|
||||||
|
if len(routesMap) > 0 {
|
||||||
|
err := m.firewall.RestoreOrCreateContainers()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for routeID := range m.routes {
|
||||||
|
update, found := routesMap[routeID]
|
||||||
|
if !found || !update.IsEqual(m.routes[routeID]) {
|
||||||
|
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, routeID := range serverRoutesToRemove {
|
||||||
|
oldRoute := m.routes[routeID]
|
||||||
|
err := m.removeFromServerNetwork(oldRoute)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||||
|
oldRoute.ID, oldRoute.Network, err)
|
||||||
|
}
|
||||||
|
delete(m.routes, routeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, newRoute := range routesMap {
|
||||||
|
_, found := m.routes[id]
|
||||||
|
if found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.addToServerNetwork(newRoute)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.routes[id] = newRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.routes) > 0 {
|
||||||
|
err := enableIPForwarding()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
log.Infof("not removing from server network because context is done")
|
||||||
|
return m.ctx.Err()
|
||||||
|
default:
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delete(m.routes, route.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
log.Infof("not adding to server network because context is done")
|
||||||
|
return m.ctx.Err()
|
||||||
|
default:
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.routes[route.ID] = route
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *serverRouter) cleanUp() {
|
||||||
|
m.firewall.CleanRoutingRules()
|
||||||
|
}
|
||||||
13
client/internal/routemanager/systemops_android.go
Normal file
13
client/internal/routemanager/systemops_android.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||||
@@ -62,12 +65,3 @@ func enableIPForwarding() error {
|
|||||||
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
|
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNetForwardHistoryEnabled() bool {
|
|
||||||
out, err := os.ReadFile(ipv4ForwardingPath)
|
|
||||||
if err != nil {
|
|
||||||
// todo
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return string(out) == "1"
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/libp2p/go-netroute"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-netroute"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errRouteNotFound = fmt.Errorf("route not found")
|
var errRouteNotFound = fmt.Errorf("route not found")
|
||||||
@@ -37,7 +37,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet)
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
defer wgInterface.Close()
|
defer wgInterface.Close()
|
||||||
|
|
||||||
@@ -4,10 +4,11 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||||
@@ -34,8 +35,3 @@ func enableIPForwarding() error {
|
|||||||
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNetForwardHistoryEnabled() bool {
|
|
||||||
log.Infof("check netforward history is not implemented on %s", runtime.GOOS)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -19,7 +19,7 @@ require (
|
|||||||
github.com/vishvananda/netlink v1.1.0
|
github.com/vishvananda/netlink v1.1.0
|
||||||
golang.org/x/crypto v0.7.0
|
golang.org/x/crypto v0.7.0
|
||||||
golang.org/x/sys v0.6.0
|
golang.org/x/sys v0.6.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
|
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.1
|
golang.zx2c4.com/wireguard/windows v0.5.1
|
||||||
google.golang.org/grpc v1.52.3
|
google.golang.org/grpc v1.52.3
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -887,6 +887,8 @@ golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+D
|
|||||||
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
|
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
|
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
|
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202 h1:KcD4X7IcoRdQpr9NSQpQpn5S4rUMIQaCdF90FOEj6KY=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202/go.mod h1:qc3aHNhM1Rc4hW2az896MjLVcxHvLbJ6LZc9MI7RTMY=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=
|
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=
|
||||||
|
|||||||
@@ -1,95 +1,416 @@
|
|||||||
package bind
|
package bind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/pion/stun"
|
"github.com/pion/stun"
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library
|
var (
|
||||||
|
_ wgConn.Bind = (*ICEBind)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ICEBind implements Bind for all platforms except Windows.
|
||||||
type ICEBind struct {
|
type ICEBind struct {
|
||||||
// below fields, initialized on open
|
mu sync.Mutex // protects following fields
|
||||||
ipv4 net.PacketConn
|
ipv4 *net.UDPConn
|
||||||
udpMux *UniversalUDPMuxDefault
|
ipv6 *net.UDPConn
|
||||||
|
blackhole4 bool
|
||||||
|
blackhole6 bool
|
||||||
|
ipv4PC *ipv4.PacketConn
|
||||||
|
ipv6PC *ipv6.PacketConn
|
||||||
|
batchSize int
|
||||||
|
udpAddrPool sync.Pool
|
||||||
|
ipv4MsgsPool sync.Pool
|
||||||
|
ipv6MsgsPool sync.Pool
|
||||||
|
|
||||||
// below are fields initialized on creation
|
// NetBird related variables
|
||||||
transportNet transport.Net
|
transportNet transport.Net
|
||||||
mu sync.Mutex
|
udpMux *UniversalUDPMuxDefault
|
||||||
|
worker *worker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewICEBind create a new instance of ICEBind with a given transportNet function.
|
|
||||||
// The transportNet can be nil.
|
|
||||||
func NewICEBind(transportNet transport.Net) *ICEBind {
|
func NewICEBind(transportNet transport.Net) *ICEBind {
|
||||||
return &ICEBind{
|
b := &ICEBind{
|
||||||
|
batchSize: wgConn.DefaultBatchSize,
|
||||||
|
|
||||||
|
udpAddrPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: make([]byte, 16),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
ipv4MsgsPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
msgs := make([]ipv4.Message, wgConn.DefaultBatchSize)
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
|
msgs[i].OOB = make([]byte, srcControlSize)
|
||||||
|
}
|
||||||
|
return &msgs
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
ipv6MsgsPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
msgs := make([]ipv6.Message, wgConn.DefaultBatchSize)
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
|
msgs[i].OOB = make([]byte, srcControlSize)
|
||||||
|
}
|
||||||
|
return &msgs
|
||||||
|
},
|
||||||
|
},
|
||||||
transportNet: transportNet,
|
transportNet: transportNet,
|
||||||
mu: sync.Mutex{},
|
}
|
||||||
|
b.worker = newWorker(b.handlePkgs)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
type StdNetEndpoint struct {
|
||||||
|
// AddrPort is the endpoint destination.
|
||||||
|
netip.AddrPort
|
||||||
|
// src is the current sticky source address and interface index, if supported.
|
||||||
|
src struct {
|
||||||
|
netip.Addr
|
||||||
|
ifidx int32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
var (
|
||||||
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
_ wgConn.Bind = (*ICEBind)(nil)
|
||||||
b.mu.Lock()
|
_ wgConn.Endpoint = &StdNetEndpoint{}
|
||||||
defer b.mu.Unlock()
|
)
|
||||||
if b.udpMux == nil {
|
|
||||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||||
|
e, err := netip.ParseAddrPort(s)
|
||||||
|
return asEndpoint(e), err
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.udpMux, nil
|
func (e *StdNetEndpoint) ClearSrc() {
|
||||||
|
e.src.ifidx = 0
|
||||||
|
e.src.Addr = netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
|
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||||
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
return e.AddrPort.Addr()
|
||||||
b.mu.Lock()
|
|
||||||
defer b.mu.Unlock()
|
|
||||||
|
|
||||||
if b.ipv4 != nil {
|
|
||||||
return nil, 0, conn.ErrBindAlreadyOpen
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
|
return e.src.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||||
|
return e.src.ifidx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||||
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstToString() string {
|
||||||
|
return e.AddrPort.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcToString() string {
|
||||||
|
return e.src.Addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve port.
|
||||||
|
laddr := conn.LocalAddr()
|
||||||
|
uaddr, err := net.ResolveUDPAddr(
|
||||||
|
laddr.Network(),
|
||||||
|
laddr.String(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
b.ipv4, _, err = listenNet("udp4", int(uport))
|
var tries int
|
||||||
|
|
||||||
|
if s.ipv4 != nil || s.ipv6 != nil {
|
||||||
|
return nil, 0, wgConn.ErrBindAlreadyOpen
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
||||||
|
// If uport is 0, we can retry on failure.
|
||||||
|
again:
|
||||||
|
port := int(uport)
|
||||||
|
var v4conn, v6conn *net.UDPConn
|
||||||
|
|
||||||
|
v4conn, port, err = listenNet("udp4", port)
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.ipv4, Net: b.transportNet})
|
// Listen on the same port as we're using for ipv4.
|
||||||
|
v6conn, port, err = listenNet("udp6", port)
|
||||||
portAddr, err := netip.ParseAddrPort(b.ipv4.LocalAddr().String())
|
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||||
if err != nil {
|
v4conn.Close()
|
||||||
|
tries++
|
||||||
|
goto again
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
v4conn.Close()
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
var fns []wgConn.ReceiveFunc
|
||||||
log.Infof("opened ICEBind on %s", b.ipv4.LocalAddr().String())
|
if v4conn != nil {
|
||||||
|
fns = append(fns, s.receiveIPv4)
|
||||||
return []conn.ReceiveFunc{
|
s.ipv4 = v4conn
|
||||||
b.makeReceiveIPv4(b.ipv4),
|
}
|
||||||
},
|
if v6conn != nil {
|
||||||
portAddr.Port(), nil
|
fns = append(fns, s.receiveIPv6)
|
||||||
|
s.ipv6 = v6conn
|
||||||
|
}
|
||||||
|
if len(fns) == 0 {
|
||||||
|
return nil, 0, syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenNet(network string, port int) (net.PacketConn, int, error) {
|
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
|
||||||
c, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
|
||||||
|
|
||||||
|
s.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: s.ipv4, Net: s.transportNet})
|
||||||
|
return fns, uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) receiveIPv4(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||||
|
defer s.ipv4MsgsPool.Put(msgs)
|
||||||
|
for i := range buffs {
|
||||||
|
(*msgs)[i].Buffers[0] = buffs[i]
|
||||||
|
}
|
||||||
|
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
lAddr := c.LocalAddr()
|
s.worker.doWork((*msgs)[:numMsgs], sizes, eps)
|
||||||
uAddr, err := net.ResolveUDPAddr(
|
return numMsgs, nil
|
||||||
lAddr.Network(),
|
}
|
||||||
lAddr.String(),
|
|
||||||
|
func (s *ICEBind) receiveIPv6(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||||
|
defer s.ipv6MsgsPool.Put(msgs)
|
||||||
|
for i := range buffs {
|
||||||
|
(*msgs)[i].Buffers[0] = buffs[i]
|
||||||
|
}
|
||||||
|
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
sizes[i] = msg.N
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := asEndpoint(addrPort)
|
||||||
|
getSrcFromControl(msg.OOB, ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) BatchSize() int {
|
||||||
|
return s.batchSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var err1, err2 error
|
||||||
|
if s.ipv4 != nil {
|
||||||
|
err1 = s.ipv4.Close()
|
||||||
|
s.ipv4 = nil
|
||||||
|
}
|
||||||
|
if s.ipv6 != nil {
|
||||||
|
err2 = s.ipv6.Close()
|
||||||
|
s.ipv6 = nil
|
||||||
|
}
|
||||||
|
s.blackhole4 = false
|
||||||
|
s.blackhole6 = false
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) Send(buffs [][]byte, endpoint wgConn.Endpoint) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
blackhole := s.blackhole4
|
||||||
|
conn := s.ipv4
|
||||||
|
is6 := false
|
||||||
|
if endpoint.DstIP().Is6() {
|
||||||
|
blackhole = s.blackhole6
|
||||||
|
conn = s.ipv6
|
||||||
|
is6 = true
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if blackhole {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
if is6 {
|
||||||
|
return s.send6(s.ipv6PC, endpoint, buffs)
|
||||||
|
} else {
|
||||||
|
return s.send4(s.ipv4PC, endpoint, buffs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.udpMux == nil {
|
||||||
|
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.udpMux, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
|
||||||
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
|
as4 := ep.DstIP().As4()
|
||||||
|
copy(ua.IP, as4[:])
|
||||||
|
ua.IP = ua.IP[:4]
|
||||||
|
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||||
|
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||||
|
for i, buff := range buffs {
|
||||||
|
(*msgs)[i].Buffers[0] = buff
|
||||||
|
(*msgs)[i].Addr = ua
|
||||||
|
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
start int
|
||||||
)
|
)
|
||||||
if err != nil {
|
for {
|
||||||
return nil, 0, err
|
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
||||||
|
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return c, uAddr.Port, nil
|
start += n
|
||||||
|
}
|
||||||
|
s.udpAddrPool.Put(ua)
|
||||||
|
s.ipv4MsgsPool.Put(msgs)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
|
||||||
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
|
as16 := ep.DstIP().As16()
|
||||||
|
copy(ua.IP, as16[:])
|
||||||
|
ua.IP = ua.IP[:16]
|
||||||
|
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||||
|
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||||
|
for i, buff := range buffs {
|
||||||
|
(*msgs)[i].Buffers[0] = buff
|
||||||
|
(*msgs)[i].Addr = ua
|
||||||
|
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
start int
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
||||||
|
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start += n
|
||||||
|
}
|
||||||
|
s.udpAddrPool.Put(ua)
|
||||||
|
s.ipv6MsgsPool.Put(msgs)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||||
|
for _, buffer := range buffers {
|
||||||
|
if !stun.IsMessage(buffer) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := parseSTUNMessage(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
buffer = []byte{}
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||||
|
if muxErr != nil {
|
||||||
|
log.Warnf("failed to handle packet")
|
||||||
|
}
|
||||||
|
buffer = []byte{}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) handlePkgs(msg *ipv4.Message) (int, *StdNetEndpoint) {
|
||||||
|
// todo: handle err
|
||||||
|
size := 0
|
||||||
|
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||||
|
if !ok {
|
||||||
|
size = msg.N
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := asEndpoint(addrPort)
|
||||||
|
getSrcFromControl(msg.OOB, ep)
|
||||||
|
return size, ep
|
||||||
|
}
|
||||||
|
|
||||||
|
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
||||||
|
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
||||||
|
// but Endpoints are immutable, so we can re-use them.
|
||||||
|
var endpointPool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return make(map[netip.AddrPort]*StdNetEndpoint)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// asEndpoint returns an Endpoint containing ap.
|
||||||
|
func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
|
||||||
|
m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
|
||||||
|
defer endpointPool.Put(m)
|
||||||
|
e, ok := m[ap]
|
||||||
|
if !ok {
|
||||||
|
e = &StdNetEndpoint{AddrPort: ap}
|
||||||
|
m[ap] = e
|
||||||
|
}
|
||||||
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||||
@@ -102,107 +423,3 @@ func parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
|||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
|
|
||||||
return func(buff []byte) (int, conn.Endpoint, error) {
|
|
||||||
n, endpoint, err := c.ReadFrom(buff)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
e, err := netip.ParseAddrPort(endpoint.String())
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
if !stun.IsMessage(buff) {
|
|
||||||
// WireGuard traffic
|
|
||||||
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := parseSTUNMessage(buff[:n])
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to handle packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
// discard packets because they are STUN related
|
|
||||||
return 0, nil, nil //todo proper return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the WireGuard socket and UDPMux
|
|
||||||
func (b *ICEBind) Close() error {
|
|
||||||
b.mu.Lock()
|
|
||||||
defer b.mu.Unlock()
|
|
||||||
|
|
||||||
var err1, err2 error
|
|
||||||
if b.ipv4 != nil {
|
|
||||||
c := b.ipv4
|
|
||||||
b.ipv4 = nil
|
|
||||||
err1 = c.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.udpMux != nil {
|
|
||||||
m := b.udpMux
|
|
||||||
b.udpMux = nil
|
|
||||||
err2 = m.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMark sets the mark for each packet sent through this Bind.
|
|
||||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
|
||||||
func (b *ICEBind) SetMark(mark uint32) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send bytes to the remote endpoint (peer)
|
|
||||||
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
|
|
||||||
|
|
||||||
nend, ok := endpoint.(conn.StdNetEndpoint)
|
|
||||||
if !ok {
|
|
||||||
return conn.ErrWrongEndpointType
|
|
||||||
}
|
|
||||||
addrPort := netip.AddrPort(nend)
|
|
||||||
_, err := b.ipv4.WriteTo(buff, &net.UDPAddr{
|
|
||||||
IP: addrPort.Addr().AsSlice(),
|
|
||||||
Port: int(addrPort.Port()),
|
|
||||||
Zone: addrPort.Addr().Zone(),
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseEndpoint creates a new endpoint from a string.
|
|
||||||
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
|
|
||||||
e, err := netip.ParseAddrPort(s)
|
|
||||||
return asEndpoint(e), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
|
||||||
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
|
||||||
// but Endpoints are immutable, so we can re-use them.
|
|
||||||
var endpointPool = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return make(map[netip.AddrPort]conn.Endpoint)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// asEndpoint returns an Endpoint containing ap.
|
|
||||||
func asEndpoint(ap netip.AddrPort) conn.Endpoint {
|
|
||||||
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint)
|
|
||||||
defer endpointPool.Put(m)
|
|
||||||
e, ok := m[ap]
|
|
||||||
if !ok {
|
|
||||||
e = conn.Endpoint(conn.StdNetEndpoint(ap))
|
|
||||||
m[ap] = e
|
|
||||||
}
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|||||||
36
iface/bind/controlfns.go
Normal file
36
iface/bind/controlfns.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||||
|
// It is used to apply platform specific configuration to the socket prior to
|
||||||
|
// bind.
|
||||||
|
type controlFn func(network, address string, c syscall.RawConn) error
|
||||||
|
|
||||||
|
// controlFns is a list of functions that are called from the listen config
|
||||||
|
// that can apply socket options.
|
||||||
|
var controlFns = []controlFn{}
|
||||||
|
|
||||||
|
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||||
|
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||||
|
// information OOB configuration for sticky sockets.
|
||||||
|
func listenConfig() *net.ListenConfig {
|
||||||
|
return &net.ListenConfig{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
for _, fn := range controlFns {
|
||||||
|
if err := fn(network, address, c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
41
iface/bind/controlfns_linux.go
Normal file
41
iface/bind/controlfns_linux.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
|
||||||
|
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||||
|
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
switch network {
|
||||||
|
case "udp4":
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||||
|
})
|
||||||
|
case "udp6":
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
28
iface/bind/controlfns_unix.go
Normal file
28
iface/bind/controlfns_unix.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
//go:build !windows && !linux && !js
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
if network == "udp6" {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
12
iface/bind/mark_default.go
Normal file
12
iface/bind/mark_default.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !linux && !openbsd && !freebsd
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
func (s *ICEBind) SetMark(mark uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
65
iface/bind/mark_unix.go
Normal file
65
iface/bind/mark_unix.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
//go:build linux || openbsd || freebsd
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
var fwmarkIoctl int
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux", "android":
|
||||||
|
fwmarkIoctl = 36 /* unix.SO_MARK */
|
||||||
|
case "freebsd":
|
||||||
|
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
|
||||||
|
case "openbsd":
|
||||||
|
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) SetMark(mark uint32) error {
|
||||||
|
var operr error
|
||||||
|
if fwmarkIoctl == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.ipv4 != nil {
|
||||||
|
fd, err := s.ipv4.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = fd.Control(func(fd uintptr) {
|
||||||
|
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
err = operr
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.ipv6 != nil {
|
||||||
|
fd, err := s.ipv6.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = fd.Control(func(fd uintptr) {
|
||||||
|
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
err = operr
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
28
iface/bind/sticky_default.go
Normal file
28
iface/bind/sticky_default.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
|
||||||
|
// use alternatively named flags and need ports and require testing.
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *wgConn.StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func setSrcControl(control *[]byte, ep *wgConn.StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// srcControlSize returns the recommended buffer size for pooling sticky control
|
||||||
|
// data.
|
||||||
|
const srcControlSize = 0
|
||||||
111
iface/bind/sticky_linux.go
Normal file
111
iface/bind/sticky_linux.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
|
ep.ClearSrc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem []byte = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IP &&
|
||||||
|
hdr.Type == unix.IP_PKTINFO {
|
||||||
|
|
||||||
|
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
|
||||||
|
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
|
||||||
|
ep.src.ifidx = info.Ifindex
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||||
|
hdr.Type == unix.IPV6_PKTINFO {
|
||||||
|
|
||||||
|
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
|
||||||
|
ep.src.Addr = netip.AddrFrom16(info.Addr)
|
||||||
|
ep.src.ifidx = int32(info.Ifindex)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
|
||||||
|
// panics if buf is of insufficient size.
|
||||||
|
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
|
||||||
|
size := int(unsafe.Sizeof(t))
|
||||||
|
if len(buf) < size {
|
||||||
|
panic("pktInfoFromBuf: buffer too small")
|
||||||
|
}
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
|
||||||
|
*control = (*control)[:0]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
|
||||||
|
*control = (*control)[:0]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(*control) < srcControlSize {
|
||||||
|
*control = (*control)[:0]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
|
||||||
|
if ep.SrcIP().Is4() {
|
||||||
|
hdr.Level = unix.IPPROTO_IP
|
||||||
|
hdr.Type = unix.IP_PKTINFO
|
||||||
|
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||||
|
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||||
|
info.Ifindex = ep.src.ifidx
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
info.Spec_dst = ep.SrcIP().As4()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
hdr.Level = unix.IPPROTO_IPV6
|
||||||
|
hdr.Type = unix.IPV6_PKTINFO
|
||||||
|
hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
|
||||||
|
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||||
|
info.Ifindex = uint32(ep.src.ifidx)
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
info.Addr = ep.SrcIP().As16()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*control = (*control)[:hdr.Len]
|
||||||
|
}
|
||||||
|
|
||||||
|
var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)
|
||||||
207
iface/bind/sticky_linux_test.go
Normal file
207
iface/bind/sticky_linux_test.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_setSrcControl(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
ep := &StdNetEndpoint{
|
||||||
|
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
|
||||||
|
}
|
||||||
|
ep.src.Addr = netip.MustParseAddr("127.0.0.1")
|
||||||
|
ep.src.ifidx = 5
|
||||||
|
|
||||||
|
control := make([]byte, srcControlSize)
|
||||||
|
|
||||||
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
if hdr.Level != unix.IPPROTO_IP {
|
||||||
|
t.Errorf("unexpected level: %d", hdr.Level)
|
||||||
|
}
|
||||||
|
if hdr.Type != unix.IP_PKTINFO {
|
||||||
|
t.Errorf("unexpected type: %d", hdr.Type)
|
||||||
|
}
|
||||||
|
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
|
||||||
|
t.Errorf("unexpected length: %d", hdr.Len)
|
||||||
|
}
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
|
||||||
|
t.Errorf("unexpected address: %v", info.Spec_dst)
|
||||||
|
}
|
||||||
|
if info.Ifindex != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
ep := &StdNetEndpoint{
|
||||||
|
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
|
||||||
|
}
|
||||||
|
ep.src.Addr = netip.MustParseAddr("::1")
|
||||||
|
ep.src.ifidx = 5
|
||||||
|
|
||||||
|
control := make([]byte, srcControlSize)
|
||||||
|
|
||||||
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
if hdr.Level != unix.IPPROTO_IPV6 {
|
||||||
|
t.Errorf("unexpected level: %d", hdr.Level)
|
||||||
|
}
|
||||||
|
if hdr.Type != unix.IPV6_PKTINFO {
|
||||||
|
t.Errorf("unexpected type: %d", hdr.Type)
|
||||||
|
}
|
||||||
|
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
|
||||||
|
t.Errorf("unexpected length: %d", hdr.Len)
|
||||||
|
}
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
if info.Addr != ep.SrcIP().As16() {
|
||||||
|
t.Errorf("unexpected address: %v", info.Addr)
|
||||||
|
}
|
||||||
|
if info.Ifindex != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
||||||
|
control := make([]byte, srcControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = 1
|
||||||
|
hdr.Type = 2
|
||||||
|
hdr.Len = 3
|
||||||
|
|
||||||
|
setSrcControl(&control, &StdNetEndpoint{})
|
||||||
|
|
||||||
|
if len(control) != 0 {
|
||||||
|
t.Errorf("unexpected control: %v", control)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_getSrcFromControl(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
control := make([]byte, srcControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IP
|
||||||
|
hdr.Type = unix.IP_PKTINFO
|
||||||
|
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
|
||||||
|
if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.src.Addr)
|
||||||
|
}
|
||||||
|
if ep.src.ifidx != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
control := make([]byte, srcControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IPV6
|
||||||
|
hdr.Type = unix.IPV6_PKTINFO
|
||||||
|
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
|
||||||
|
if ep.SrcIP() != netip.MustParseAddr("::1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.src.ifidx != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("ClearOnEmpty", func(t *testing.T) {
|
||||||
|
control := make([]byte, srcControlSize)
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
ep.src.Addr = netip.MustParseAddr("::1")
|
||||||
|
ep.src.ifidx = 5
|
||||||
|
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
t.Errorf("unexpected address: %v", ep.src.Addr)
|
||||||
|
}
|
||||||
|
if ep.src.ifidx != 0 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_listenConfig(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
var i int
|
||||||
|
sc.Control(func(fd uintptr) {
|
||||||
|
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != 1 {
|
||||||
|
t.Error("IP_PKTINFO not set!")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
var i int
|
||||||
|
sc.Control(func(fd uintptr) {
|
||||||
|
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != 1 {
|
||||||
|
t.Error("IPV6_PKTINFO not set!")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
56
iface/bind/worker.go
Normal file
56
iface/bind/worker.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// todo: add close function
|
||||||
|
type worker struct {
|
||||||
|
jobOffer chan int
|
||||||
|
numOfWorker int
|
||||||
|
|
||||||
|
jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)
|
||||||
|
|
||||||
|
messages []ipv4.Message
|
||||||
|
sizes []int
|
||||||
|
eps []wgConn.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWorker(jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)) *worker {
|
||||||
|
w := &worker{
|
||||||
|
jobOffer: make(chan int),
|
||||||
|
numOfWorker: runtime.NumCPU(),
|
||||||
|
jobFn: jobFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.populateWorkers()
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *worker) doWork(messages []ipv4.Message, sizes []int, eps []wgConn.Endpoint) {
|
||||||
|
w.messages = messages
|
||||||
|
w.sizes = sizes
|
||||||
|
w.eps = eps
|
||||||
|
|
||||||
|
for i := 0; i < len(messages); i++ {
|
||||||
|
w.jobOffer <- i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *worker) populateWorkers() {
|
||||||
|
for i := 0; i < w.numOfWorker; i++ {
|
||||||
|
go w.loop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *worker) loop() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msgPos := <-w.jobOffer:
|
||||||
|
w.sizes[msgPos], w.eps[msgPos] = w.jobFn(&w.messages[msgPos])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
|
func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
}
|
}
|
||||||
@@ -17,7 +17,7 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter
|
|||||||
return wgIFace, err
|
return wgIFace, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet)
|
tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet)
|
||||||
wgIFace.tun = tun
|
wgIFace.tun = tun
|
||||||
|
|
||||||
wgIFace.configurer = newWGConfigurer(tun)
|
wgIFace.configurer = newWGConfigurer(tun)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
|
func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet)
|
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func Test_CreateInterface(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -136,7 +136,7 @@ func Test_Close(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -168,7 +168,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -219,7 +219,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -282,7 +282,7 @@ func Test_RemovePeer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
|
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -335,7 +335,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet)
|
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -356,7 +356,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet)
|
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,6 @@ package iface
|
|||||||
|
|
||||||
// TunAdapter is an interface for create tun device from externel service
|
// TunAdapter is an interface for create tun device from externel service
|
||||||
type TunAdapter interface {
|
type TunAdapter interface {
|
||||||
ConfigureInterface(address string, mtu int) (int, error)
|
ConfigureInterface(address string, mtu int, routes string) (int, error)
|
||||||
UpdateAddr(address string) error
|
UpdateAddr(address string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +1,34 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"strings"
|
||||||
|
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface/bind"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tunDevice struct {
|
type tunDevice struct {
|
||||||
address WGAddress
|
address WGAddress
|
||||||
mtu int
|
mtu int
|
||||||
|
routes []string
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
|
|
||||||
fd int
|
fd int
|
||||||
name string
|
name string
|
||||||
device *device.Device
|
device *device.Device
|
||||||
uapi net.Listener
|
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
|
func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
|
||||||
return &tunDevice{
|
return &tunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
|
routes: routes,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
iceBind: bind.NewICEBind(transportNet),
|
iceBind: bind.NewICEBind(transportNet),
|
||||||
}
|
}
|
||||||
@@ -37,7 +36,8 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
|
|||||||
|
|
||||||
func (t *tunDevice) Create() error {
|
func (t *tunDevice) Create() error {
|
||||||
var err error
|
var err error
|
||||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu)
|
routesString := t.routesToString()
|
||||||
|
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
return err
|
return err
|
||||||
@@ -54,32 +54,8 @@ func (t *tunDevice) Create() error {
|
|||||||
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||||
t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
log.Debugf("create uapi")
|
|
||||||
tunSock, err := ipc.UAPIOpen(name)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.uapi, err = ipc.UAPIListen(name, tunSock)
|
|
||||||
if err != nil {
|
|
||||||
tunSock.Close()
|
|
||||||
unix.Close(t.fd)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
uapiConn, err := t.uapi.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go t.device.IpcHandle(uapiConn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = t.device.Up()
|
err = t.device.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunSock.Close()
|
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -105,13 +81,13 @@ func (t *tunDevice) UpdateAddr(addr WGAddress) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tunDevice) Close() (err error) {
|
func (t *tunDevice) Close() (err error) {
|
||||||
if t.uapi != nil {
|
|
||||||
err = t.uapi.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.device != nil {
|
if t.device != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tunDevice) routesToString() string {
|
||||||
|
return strings.Join(t.routes, ";")
|
||||||
|
}
|
||||||
|
|||||||
@@ -172,6 +172,49 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRoutes return with the routes
|
||||||
|
func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
|
||||||
|
serverPubKey, err := c.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed getting Management Service public key: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||||
|
defer cancelStream()
|
||||||
|
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to open Management Service stream: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = stream.CloseSend()
|
||||||
|
}()
|
||||||
|
|
||||||
|
update, err := stream.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
log.Debugf("Management stream has been closed by server: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("disconnected from Management Service sync stream: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptedResp := &proto.SyncResponse{}
|
||||||
|
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if decryptedResp.GetNetworkMap() == nil {
|
||||||
|
return nil, fmt.Errorf("invalid msg, required network map")
|
||||||
|
}
|
||||||
|
|
||||||
|
return decryptedResp.GetNetworkMap().GetRoutes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||||
req := &proto.SyncRequest{}
|
req := &proto.SyncRequest{}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user