From 32146e576d9e817a4fb84a87631521eceed5b503 Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Fri, 21 Nov 2025 09:36:33 -0300 Subject: [PATCH] [android] allow selection/deselection of network resources on android peers (#4607) --- client/android/client.go | 115 ++++++- client/android/network_domains.go | 56 ++++ client/android/networks.go | 14 +- client/android/peer_notifier.go | 7 + client/android/peer_routes.go | 20 ++ client/android/platform_files.go | 10 + client/android/route_command.go | 67 ++++ client/iface/device/device_android.go | 52 ++- .../iface/device/device_netstack_android.go | 7 + client/iface/device/renewable_tun.go | 309 ++++++++++++++++++ client/iface/device_android.go | 1 + client/iface/iface_create.go | 4 + client/iface/iface_create_android.go | 7 + client/iface/iface_create_darwin.go | 4 + client/internal/connect.go | 2 + client/internal/engine.go | 14 +- client/internal/engine_test.go | 4 + client/internal/iface_common.go | 1 + 18 files changed, 666 insertions(+), 28 deletions(-) create mode 100644 client/android/network_domains.go create mode 100644 client/android/peer_routes.go create mode 100644 client/android/platform_files.go create mode 100644 client/android/route_command.go create mode 100644 client/iface/device/renewable_tun.go diff --git a/client/android/client.go b/client/android/client.go index 86fb1445d..2943702c6 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,10 +4,13 @@ package android import ( "context" + "fmt" "os" "slices" "sync" + "golang.org/x/exp/maps" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/device" @@ -16,10 +19,13 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile @@ -62,17 +68,18 @@ type Client struct { deviceName string uiVersion string networkChangeListener listener.NetworkChangeListener + stateFile string connectClient *internal.ConnectClient } // NewClient instantiate a new Client -func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { execWorkaround(androidSDKVersion) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ - cfgFile: cfgFile, + cfgFile: platformFiles.ConfigurationFilePath(), deviceName: deviceName, uiVersion: uiVersion, tunAdapter: tunAdapter, @@ -80,6 +87,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi recorder: peer.NewRecorder(""), ctxCancelLock: &sync.Mutex{}, networkChangeListener: networkChangeListener, + stateFile: platformFiles.StateFilePath(), } } @@ -115,7 +123,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -142,7 +150,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } // Stop the internal client and free the resources @@ -156,6 +164,19 @@ func (c *Client) Stop() { c.ctxCancel() } +func (c *Client) RenewTun(fd int) error { + if c.connectClient == nil { + return fmt.Errorf("engine not running") + } + + e := c.connectClient.Engine() + if e == nil { + return fmt.Errorf("engine not initialized") + } + + return e.RenewTun(fd) +} + // SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) @@ -177,6 +198,7 @@ func (c *Client) PeersList() *PeerInfoArray { p.IP, p.FQDN, p.ConnStatus.String(), + PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi } @@ -201,31 +223,43 @@ func (c *Client) Networks() *NetworkArray { return nil } + routeSelector := routeManager.GetRouteSelector() + if routeSelector == nil { + log.Error("could not get route selector") + return nil + } + networkArray := &NetworkArray{ items: make([]Network, 0), } + resolvedDomains := c.recorder.GetResolvedDomainsStates() + for id, routes := range routeManager.GetClientRoutesWithNetID() { if len(routes) == 0 { continue } r := routes[0] + domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) netStr := r.Network.String() + if r.IsDynamic() { netStr = r.Domains.SafeString() } - peer, err := c.recorder.GetPeer(routes[0].Peer) + routePeer, err := c.recorder.GetPeer(routes[0].Peer) if err != nil { log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) continue } network := Network{ - Name: string(id), - Network: netStr, - Peer: peer.FQDN, - Status: peer.ConnStatus.String(), + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: routeSelector.IsSelected(id), + Domains: domains, } networkArray.Add(network) } @@ -253,6 +287,69 @@ func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } +func (c *Client) toggleRoute(command routeCommand) error { + return command.toggleRoute() +} + +func (c *Client) getRouteManager() (routemanager.Manager, error) { + client := c.connectClient + if client == nil { + return nil, fmt.Errorf("not connected") + } + + engine := client.Engine() + if engine == nil { + return nil, fmt.Errorf("engine is not running") + } + + manager := engine.GetRouteManager() + if manager == nil { + return nil, fmt.Errorf("could not get route manager") + } + + return manager, nil +} + +func (c *Client) SelectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(selectRouteCommand{route: route, manager: manager}) +} + +func (c *Client) DeselectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(deselectRouteCommand{route: route, manager: manager}) +} + +// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain +// with its resolved IP addresses from the provided resolvedDomains map. +func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains { + domains := NetworkDomains{} + + for _, d := range route.Domains { + networkDomain := NetworkDomain{ + Address: d.SafeString(), + } + + if info, exists := resolvedDomains[d]; exists { + for _, prefix := range info.Prefixes { + networkDomain.addResolvedIP(prefix.Addr().String()) + } + } + + domains.Add(&networkDomain) + } + + return domains +} + func exportEnvList(list *EnvList) { if list == nil { return diff --git a/client/android/network_domains.go b/client/android/network_domains.go new file mode 100644 index 000000000..a459bdc23 --- /dev/null +++ b/client/android/network_domains.go @@ -0,0 +1,56 @@ +//go:build android + +package android + +import "fmt" + +type ResolvedIPs struct { + resolvedIPs []string +} + +func (r *ResolvedIPs) Add(ipAddress string) { + r.resolvedIPs = append(r.resolvedIPs, ipAddress) +} + +func (r *ResolvedIPs) Get(i int) (string, error) { + if i < 0 || i >= len(r.resolvedIPs) { + return "", fmt.Errorf("%d is out of range", i) + } + return r.resolvedIPs[i], nil +} + +func (r *ResolvedIPs) Size() int { + return len(r.resolvedIPs) +} + +type NetworkDomain struct { + Address string + resolvedIPs ResolvedIPs +} + +func (d *NetworkDomain) addResolvedIP(resolvedIP string) { + d.resolvedIPs.Add(resolvedIP) +} + +func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs { + return &d.resolvedIPs +} + +type NetworkDomains struct { + domains []*NetworkDomain +} + +func (n *NetworkDomains) Add(domain *NetworkDomain) { + n.domains = append(n.domains, domain) +} + +func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) { + if i < 0 || i >= len(n.domains) { + return nil, fmt.Errorf("%d is out of range", i) + } + return n.domains[i], nil +} + +func (n *NetworkDomains) Size() int { + return len(n.domains) +} diff --git a/client/android/networks.go b/client/android/networks.go index aa130420b..3c3a25939 100644 --- a/client/android/networks.go +++ b/client/android/networks.go @@ -3,10 +3,16 @@ package android type Network struct { - Name string - Network string - Peer string - Status string + Name string + Network string + Peer string + Status string + IsSelected bool + Domains NetworkDomains +} + +func (n Network) GetNetworkDomains() *NetworkDomains { + return &n.Domains } type NetworkArray struct { diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index 1f5564c72..b03947da1 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -1,3 +1,5 @@ +//go:build android + package android // PeerInfo describe information about the peers. It designed for the UI usage @@ -5,6 +7,11 @@ type PeerInfo struct { IP string FQDN string ConnStatus string // Todo replace to enum + Routes PeerRoutes +} + +func (p *PeerInfo) GetPeerRoutes() *PeerRoutes { + return &p.Routes } // PeerInfoArray is a wrapper of []PeerInfo diff --git a/client/android/peer_routes.go b/client/android/peer_routes.go new file mode 100644 index 000000000..bb46d609f --- /dev/null +++ b/client/android/peer_routes.go @@ -0,0 +1,20 @@ +//go:build android + +package android + +import "fmt" + +type PeerRoutes struct { + routes []string +} + +func (p *PeerRoutes) Get(i int) (string, error) { + if i < 0 || i >= len(p.routes) { + return "", fmt.Errorf("%d is out of range", i) + } + return p.routes[i], nil +} + +func (p *PeerRoutes) Size() int { + return len(p.routes) +} diff --git a/client/android/platform_files.go b/client/android/platform_files.go new file mode 100644 index 000000000..f0c369750 --- /dev/null +++ b/client/android/platform_files.go @@ -0,0 +1,10 @@ +//go:build android + +package android + +// PlatformFiles groups paths to files used internally by the engine that can't be created/modified +// at their default locations due to android OS restrictions. +type PlatformFiles interface { + ConfigurationFilePath() string + StateFilePath() string +} diff --git a/client/android/route_command.go b/client/android/route_command.go new file mode 100644 index 000000000..b47d5ca6c --- /dev/null +++ b/client/android/route_command.go @@ -0,0 +1,67 @@ +//go:build android + +package android + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/route" +) + +func executeRouteToggle(id string, manager routemanager.Manager, + operationName string, + routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error { + netID := route.NetID(id) + routes := []route.NetID{netID} + + log.Debugf("%s with id: %s", operationName, id) + + if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("error when %s: %s", operationName, err) + return fmt.Errorf("error %s: %w", operationName, err) + } + + manager.TriggerSelection(manager.GetClientRoutes()) + + return nil +} + +type routeCommand interface { + toggleRoute() error +} + +type selectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (s selectRouteCommand) toggleRoute() error { + routeSelector := s.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error { + return routeSelector.SelectRoutes(routes, true, allRoutes) + } + + return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation) +} + +type deselectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (d deselectRouteCommand) toggleRoute() error { + routeSelector := d.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes) +} diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 48346fc0f..198343fbd 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -3,6 +3,7 @@ package device import ( + "fmt" "strings" log "github.com/sirupsen/logrus" @@ -19,11 +20,12 @@ import ( // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform type WGTunDevice struct { - address wgaddr.Address - port int - key string - mtu uint16 - iceBind *bind.ICEBind + address wgaddr.Address + port int + key string + mtu uint16 + iceBind *bind.ICEBind + // todo: review if we can eliminate the TunAdapter tunAdapter TunAdapter disableDNS bool @@ -32,17 +34,19 @@ type WGTunDevice struct { filteredDevice *FilteredDevice udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer + renewableTun *RenewableTUN } func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ - address: address, - port: port, - key: key, - mtu: mtu, - iceBind: iceBind, - tunAdapter: tunAdapter, - disableDNS: disableDNS, + address: address, + port: port, + key: key, + mtu: mtu, + iceBind: iceBind, + tunAdapter: tunAdapter, + disableDNS: disableDNS, + renewableTun: NewRenewableTUN(), } } @@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } - tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd) + unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd) if err != nil { _ = unix.Close(fd) log.Errorf("failed to create Android interface: %s", err) return nil, err } + + t.renewableTun.AddDevice(unmonitoredTUN) + t.name = name - t.filteredDevice = newDeviceFilter(tunDevice) + t.filteredDevice = newDeviceFilter(t.renewableTun) log.Debugf("attaching to interface %v", name) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] ")) @@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return udpMux, nil } +func (t *WGTunDevice) RenewTun(fd int) error { + if t.device == nil { + return fmt.Errorf("device not initialized") + } + + unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd) + if err != nil { + _ = unix.Close(fd) + log.Errorf("failed to renew Android interface: %s", err) + return err + } + + t.renewableTun.AddDevice(unmonitoredTUN) + + return nil +} + func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error { // todo implement return nil diff --git a/client/iface/device/device_netstack_android.go b/client/iface/device/device_netstack_android.go index 45ae8ba7d..f1a77d40a 100644 --- a/client/iface/device/device_netstack_android.go +++ b/client/iface/device/device_netstack_android.go @@ -2,6 +2,13 @@ package device +import "fmt" + func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { return t.create() } + +func (t *TunNetstackDevice) RenewTun(fd int) error { + // Doesn't make sense in Android for Netstack. + return fmt.Errorf("this function has not been implemented in Netstack for Android") +} diff --git a/client/iface/device/renewable_tun.go b/client/iface/device/renewable_tun.go new file mode 100644 index 000000000..a501eebbb --- /dev/null +++ b/client/iface/device/renewable_tun.go @@ -0,0 +1,309 @@ +//go:build android + +package device + +import ( + "io" + "os" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" +) + +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +// +// It also redirects tun.Device's Events() to a separate goroutine +// and closes it when Close is called. +// +// The WaitGroup and CloseOnce fields are used to ensure that the +// goroutine is awaited and closed only once. +type closeAwareDevice struct { + isClosed atomic.Bool + tun.Device + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newClosableDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel (RenewableTUN's events channel). +func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() +} + +type RenewableTUN struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + events chan tun.Event + closed atomic.Bool +} + +func NewRenewableTUN() *RenewableTUN { + r := &RenewableTUN{ + devices: make([]*closeAwareDevice, 0), + mu: sync.Mutex{}, + events: make(chan tun.Event, 16), + } + r.cond = sync.NewCond(&r.mu) + return r +} + +func (r *RenewableTUN) File() *os.File { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// Read reads from an underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries reading from the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + for { + dev := r.peekLast() + if dev == nil { + // wait until AddDevice() signals a new device via cond.Broadcast() + if !r.waitForDevice() { // returns false if the renewable TUN itself is closed + return 0, io.EOF + } + continue + } + + n, err = dev.Read(bufs, sizes, offset) + if err == nil { + return n, nil + } + + // swap in progress; retry on the newest instead of returning the error + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return n, err // propagate non-swap error + } +} + +// Write writes to underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries writing to the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} + +func (r *RenewableTUN) MTU() (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return 0, err + } +} + +func (r *RenewableTUN) Name() (string, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return "", io.EOF + } + continue + } + name, err := dev.Name() + if err == nil { + return name, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return "", err + } +} + +// Events returns a channel that is fed events from the underlying tun.Device's events channel +// once it is added. +func (r *RenewableTUN) Events() <-chan tun.Event { + return r.events +} + +func (r *RenewableTUN) Close() error { + // Attempts to set the RenewableTUN closed flag to true. + // If it's already true, returns immediately. + if !r.closed.CompareAndSwap(false, true) { + return nil // already closed: idempotent + } + r.mu.Lock() + devices := r.devices + r.devices = nil + r.cond.Broadcast() + r.mu.Unlock() + + var lastErr error + + log.Debugf("closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + log.Debugf("error closing a device: %v", err) + lastErr = err + } + } + + close(r.events) + return lastErr +} + +func (r *RenewableTUN) BatchSize() int { + return 1 +} + +func (r *RenewableTUN) AddDevice(device tun.Device) { + r.mu.Lock() + if r.closed.Load() { + r.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(r.devices) > 0 { + toClose = r.devices[len(r.devices)-1] + } + + cad := newClosableDevice(device) + cad.redirectEvents(r.events) + + r.devices = []*closeAwareDevice{cad} + r.cond.Broadcast() + + r.mu.Unlock() + + if toClose != nil { + if err := toClose.Close(); err != nil { + log.Debugf("error closing last device: %v", err) + } + } +} + +func (r *RenewableTUN) waitForDevice() bool { + r.mu.Lock() + defer r.mu.Unlock() + + for len(r.devices) == 0 && !r.closed.Load() { + r.cond.Wait() + } + return !r.closed.Load() +} + +func (r *RenewableTUN) peekLast() *closeAwareDevice { + r.mu.Lock() + defer r.mu.Unlock() + + if len(r.devices) == 0 { + return nil + } + + return r.devices[len(r.devices)-1] +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index cdfcea48d..3899bf426 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,5 +21,6 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + RenewTun(fd int) error GetICEBind() device.EndpointManager } diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go index 5e17c6d41..13ae9393c 100644 --- a/client/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -24,3 +24,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on non mobile") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on non-android") +} diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go index 373a9c95a..d2d9eb70e 100644 --- a/client/iface/iface_create_android.go +++ b/client/iface/iface_create_android.go @@ -6,6 +6,7 @@ import ( // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. +// todo: review does this function really necessary or can we merge it with iOS func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { w.mu.Lock() defer w.mu.Unlock() @@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s func (w *WGIface) Create() error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + w.mu.Lock() + defer w.mu.Unlock() + return w.tun.RenewTun(fd) +} diff --git a/client/iface/iface_create_darwin.go b/client/iface/iface_create_darwin.go index 1d91bce54..0b7cd36ef 100644 --- a/client/iface/iface_create_darwin.go +++ b/client/iface/iface_create_darwin.go @@ -39,3 +39,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on this platform") +} diff --git a/client/internal/connect.go b/client/internal/connect.go index 6ad5f264b..5a5f4f63c 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -74,6 +74,7 @@ func (c *ConnectClient) RunOnAndroid( networkChangeListener listener.NetworkChangeListener, dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, + stateFilePath string, ) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ @@ -82,6 +83,7 @@ func (c *ConnectClient) RunOnAndroid( NetworkChangeListener: networkChangeListener, HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 94c948398..76265dd77 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -255,7 +255,7 @@ func NewEngine( sm := profilemanager.NewServiceManager("") path := sm.GetStatePath() - if runtime.GOOS == "ios" { + if runtime.GOOS == "ios" || runtime.GOOS == "android" { if !fileExists(mobileDep.StateFilePath) { err := createFile(mobileDep.StateFilePath) if err != nil { @@ -1831,6 +1831,18 @@ func (e *Engine) GetWgAddr() netip.Addr { return e.wgInterface.Address().IP } +func (e *Engine) RenewTun(fd int) error { + e.syncMsgMux.Lock() + wgInterface := e.wgInterface + e.syncMsgMux.Unlock() + + if wgInterface == nil { + return fmt.Errorf("wireguard interface not initialized") + } + + return wgInterface.RenewTun(fd) +} + // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag func (e *Engine) updateDNSForwarder( enabled bool, diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d6ab7391e..9252ce13e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -110,6 +110,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RenewTun(_ int) error { + return nil +} + func (m *MockWGIface) RemoveEndpointAddress(_ string) error { return nil } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 98fe01912..90b06cbd1 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -20,6 +20,7 @@ import ( type wgIfaceBase interface { Create() error CreateOnAndroid(routeRange []string, ip string, domains []string) error + RenewTun(fd int) error IsUserspaceBind() bool Name() string Address() wgaddr.Address