diff --git a/client/internal/engine.go b/client/internal/engine.go index ec513391a..fbd8ee6d8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -140,7 +140,7 @@ type Engine struct { ctx context.Context cancel context.CancelFunc - wgInterface *iface.WGIface + wgInterface iface.IWGIface wgProxyFactory *wgproxy.Factory udpMux *bind.UniversalUDPMuxDefault diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 79b3cd498..b8a85b071 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -58,6 +58,12 @@ var ( } ) +func TestMain(m *testing.M) { + _ = util.InitLog("debug", "console") + code := m.Run() + os.Exit(code) +} + func TestEngine_SSH(t *testing.T) { // todo resolve test execution on freebsd if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { @@ -74,14 +80,22 @@ func TestEngine_SSH(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String()) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ - WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - ServerSSHAllowed: true, - }, - MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + engine := NewEngine( + ctx, cancel, + &signal.MockClient{}, + &mgmt.MockClient{}, + relayMgr, + &EngineConfig{ + WgIfaceName: "utun101", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + ServerSSHAllowed: true, + }, + MobileDependency{}, + peer.NewRecorder("https://mgm"), + nil, + ) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -211,20 +225,27 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String()) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) - if err != nil { - t.Fatal(err) + engine := NewEngine( + ctx, cancel, + &signal.MockClient{}, + &mgmt.MockClient{}, + relayMgr, + &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + }, + MobileDependency{}, + peer.NewRecorder("https://mgm"), + nil) + + wgIface := &iface.MockWGIface{ + RemovePeerFunc: func(peerKey string) error { + return nil + }, } + engine.wgInterface = wgIface engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 3901709ef..759b9e999 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -35,7 +35,7 @@ const ( type WgConfig struct { WgListenPort int RemoteKey string - WgInterface *iface.WGIface + WgInterface iface.IWGIface AllowedIps string PreSharedKey *wgtypes.Key } @@ -91,6 +91,7 @@ type Conn struct { statusRelay ConnStatus statusICE ConnStatus currentConnType ConnPriority + opened bool // this flag is used to prevent close in case of not opened connection workerICE *WorkerICE workerRelay *WorkerRelay @@ -167,6 +168,9 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // be used. func (conn *Conn) Open() { conn.log.Debugf("open connection to peer") + conn.mu.Lock() + defer conn.mu.Unlock() + conn.opened = true peerState := State{ PubKey: conn.config.Key, @@ -197,6 +201,11 @@ func (conn *Conn) Close() { conn.ctxCancel() + if !conn.opened { + log.Infof("IGNORE close connection to peer") + return + } + if conn.wgProxyRelay != nil { err := conn.wgProxyRelay.CloseConn() if err != nil { @@ -302,7 +311,7 @@ func (conn *Conn) GetKey() string { } func (conn *Conn) reconnectLoop() { - ticker := time.NewTicker(conn.config.Timeout) // todo use the interval from config + ticker := time.NewTicker(conn.config.Timeout) if !conn.workerRelay.IsController() { ticker.Stop() } else { diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 3c230df21..57ed9ed1e 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -42,7 +42,7 @@ type clientNetwork struct { ctx context.Context cancel context.CancelFunc statusRecorder *peer.Status - wgInterface *iface.WGIface + wgInterface iface.IWGIface routes map[route.ID]*route.Route routeUpdate chan routesUpdate peerStateUpdate chan struct{} @@ -52,7 +52,7 @@ type clientNetwork struct { updateSerial uint64 } -func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { +func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { ctx, cancel := context.WithCancel(ctx) client := &clientNetwork{ diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0673ea6c3..2e1378414 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -48,7 +48,7 @@ type DefaultManager struct { serverRouter serverRouter sysOps *systemops.SysOps statusRecorder *peer.Status - wgInterface *iface.WGIface + wgInterface iface.IWGIface pubKey string notifier *notifier routeRefCounter *refcounter.RouteRefCounter @@ -60,7 +60,7 @@ func NewManager( ctx context.Context, pubKey string, dnsRouteInterval time.Duration, - wgInterface *iface.WGIface, + wgInterface iface.IWGIface, statusRecorder *peer.Status, initialRoutes []*route.Route, ) *DefaultManager { diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 8470934c2..43a266cd2 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -22,11 +22,11 @@ type defaultServerRouter struct { ctx context.Context routes map[route.ID]*route.Route firewall firewall.Manager - wgInterface *iface.WGIface + wgInterface iface.IWGIface statusRecorder *peer.Status } -func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { +func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { return &defaultServerRouter{ ctx: ctx, routes: make(map[route.ID]*route.Route), diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go index 3f2937c89..3e176860f 100644 --- a/client/internal/routemanager/sysctl/sysctl_linux.go +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -22,7 +22,7 @@ const ( ) // Setup configures sysctl settings for RP filtering and source validation. -func Setup(wgIface *iface.WGIface) (map[string]int, error) { +func Setup(wgIface iface.IWGIface) (map[string]int, error) { keys := map[string]int{} var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 9ee51538b..fa7ab0290 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -17,10 +17,10 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop] type SysOps struct { refCounter *ExclusionCounter - wgInterface *iface.WGIface + wgInterface iface.IWGIface } -func NewSysOps(wgInterface *iface.WGIface) *SysOps { +func NewSysOps(wgInterface iface.IWGIface) *SysOps { return &SysOps{ wgInterface: wgInterface, } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 0d1c16ca1..0c152d233 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) { +func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) { addr := prefix.Addr() switch { case addr.IsLoopback(), diff --git a/iface/iface.go b/iface/iface.go index 928077a3d..448a47763 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -17,6 +17,26 @@ const ( DefaultWgPort = 51820 ) +type IWGIface interface { + Create() error + CreateOnAndroid(routeRange []string, ip string, domains []string) error + IsUserspaceBind() bool + Name() string + Address() WGAddress + ToInterface() *net.Interface + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(newAddr string) error + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + Close() error + SetFilter(filter PacketFilter) error + GetFilter() PacketFilter + GetDevice() *DeviceWrapper + GetStats(peerKey string) (WGStats, error) +} + // WGIface represents a interface instance type WGIface struct { tun wgTunDevice diff --git a/iface/iface_moc.go b/iface/iface_moc.go new file mode 100644 index 000000000..179c5fae4 --- /dev/null +++ b/iface/iface_moc.go @@ -0,0 +1,98 @@ +package iface + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/iface/bind" +) + +type MockWGIface struct { + IsUserspaceBindFunc func() bool + NameFunc func() string + AddressFunc func() WGAddress + ToInterfaceFunc func() *net.Interface + UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpdateAddrFunc func(newAddr string) error + UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeerFunc func(peerKey string) error + AddAllowedIPFunc func(peerKey string, allowedIP string) error + RemoveAllowedIPFunc func(peerKey string, allowedIP string) error + CloseFunc func() error + SetFilterFunc func(filter PacketFilter) error + GetFilterFunc func() PacketFilter + GetDeviceFunc func() *DeviceWrapper + GetStatsFunc func(peerKey string) (WGStats, error) +} + +func (m *MockWGIface) Create() error { + //TODO implement me + panic("implement me") +} + +func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error { + //TODO implement me + panic("implement me") +} + +func (m *MockWGIface) IsUserspaceBind() bool { + return m.IsUserspaceBindFunc() +} + +func (m *MockWGIface) Name() string { + return m.NameFunc() +} + +func (m *MockWGIface) Address() WGAddress { + return m.AddressFunc() +} + +func (m *MockWGIface) ToInterface() *net.Interface { + return m.ToInterfaceFunc() +} + +func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { + return m.UpFunc() +} + +func (m *MockWGIface) UpdateAddr(newAddr string) error { + return m.UpdateAddrFunc(newAddr) +} + +func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) +} + +func (m *MockWGIface) RemovePeer(peerKey string) error { + return m.RemovePeerFunc(peerKey) +} + +func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error { + return m.AddAllowedIPFunc(peerKey, allowedIP) +} + +func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { + return m.RemoveAllowedIPFunc(peerKey, allowedIP) +} + +func (m *MockWGIface) Close() error { + return m.CloseFunc() +} + +func (m *MockWGIface) SetFilter(filter PacketFilter) error { + return m.SetFilterFunc(filter) +} + +func (m *MockWGIface) GetFilter() PacketFilter { + return m.GetFilterFunc() +} + +func (m *MockWGIface) GetDevice() *DeviceWrapper { + return m.GetDeviceFunc() +} + +func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) { + return m.GetStatsFunc(peerKey) +}