diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 13228250d..50cb4e2af 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -33,6 +33,10 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v3 + - name: Check for duplicate constants + if: matrix.os == 'ubuntu-latest' + run: | + ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . - name: Install Go uses: actions/setup-go@v4 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index 3504554a2..1cce4c7a8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -833,6 +833,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { FQDN: offlinePeer.GetFqdn(), ConnStatus: peer.StatusDisconnected, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } } e.statusRecorder.ReplaceOfflinePeers(replacement) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index f3d07dcad..9e7ee6959 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error { } conn.agent, err = ice.NewAgent(agentConfig) - if err != nil { return err } @@ -285,6 +284,7 @@ func (conn *Conn) Open() error { IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], ConnStatusUpdate: time.Now(), ConnStatus: conn.status, + Mux: new(sync.RWMutex), } err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { @@ -344,6 +344,7 @@ func (conn *Conn) Open() error { PubKey: conn.config.Key, ConnStatus: conn.status, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } err = conn.statusRecorder.UpdatePeerState(peerState) if err != nil { @@ -468,6 +469,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()), Direct: !isRelayCandidate(pair.Local), RosenpassEnabled: rosenpassEnabled, + Mux: new(sync.RWMutex), } if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { peerState.Relayed = true @@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error { PubKey: conn.config.Key, ConnStatus: conn.status, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index ca97c3ea4..ddea7d04e 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -14,6 +14,7 @@ import ( // State contains the latest state of a peer type State struct { + Mux *sync.RWMutex IP string PubKey string FQDN string @@ -30,7 +31,38 @@ type State struct { BytesRx int64 Latency time.Duration RosenpassEnabled bool - Routes map[string]struct{} + routes map[string]struct{} +} + +// AddRoute add a single route to routes map +func (s *State) AddRoute(network string) { + s.Mux.Lock() + if s.routes == nil { + s.routes = make(map[string]struct{}) + } + s.routes[network] = struct{}{} + s.Mux.Unlock() +} + +// SetRoutes set state routes +func (s *State) SetRoutes(routes map[string]struct{}) { + s.Mux.Lock() + s.routes = routes + s.Mux.Unlock() +} + +// DeleteRoute removes a route from the network amp +func (s *State) DeleteRoute(network string) { + s.Mux.Lock() + delete(s.routes, network) + s.Mux.Unlock() +} + +// GetRoutes return routes map +func (s *State) GetRoutes() map[string]struct{} { + s.Mux.RLock() + defer s.Mux.RUnlock() + return s.routes } // LocalPeerState contains the latest state of the local peer @@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error { PubKey: peerPubKey, ConnStatus: StatusDisconnected, FQDN: fqdn, + Mux: new(sync.RWMutex), } d.peerListChangedForNotification = true return nil @@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } - if receivedState.Routes != nil { - peerState.Routes = receivedState.Routes + if receivedState.GetRoutes() != nil { + peerState.SetRoutes(receivedState.GetRoutes()) } skipNotification := shouldSkipNotify(receivedState, peerState) @@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool { s, ok := gstatus.FromError(d.managementError) if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return true - } return false } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 9038371bd..a4a6e6081 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -3,6 +3,7 @@ package peer import ( "errors" "testing" + "sync" "github.com/stretchr/testify/assert" ) @@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 370ad5cf4..d41ed422b 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -196,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { return fmt.Errorf("get peer state: %v", err) } - delete(state.Routes, c.network.String()) + state.DeleteRoute(c.network.String()) if err := c.statusRecorder.UpdatePeerState(state); err != nil { log.Warnf("Failed to update peer state: %v", err) } @@ -268,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { if err != nil { log.Errorf("Failed to get peer state: %v", err) } else { - if state.Routes == nil { - state.Routes = map[string]struct{}{} - } - state.Routes[c.network.String()] = struct{}{} + state.AddRoute(c.network.String()) if err := c.statusRecorder.UpdatePeerState(state); err != nil { log.Warnf("Failed to update peer state: %v", err) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 36a37f02c..0dfc0f7e0 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, // Init sets up the routing func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + if nbnet.CustomRoutingDisabled() { + return nil, nil, nil + } + if err := cleanupRouting(); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } - if err := cleanupRouting(); err != nil { - log.Errorf("Error cleaning up routing: %v", err) - } else { - log.Info("Routing cleanup complete") + + if !nbnet.CustomRoutingDisabled() { + if err := cleanupRouting(); err != nil { + log.Errorf("Error cleaning up routing: %v", err) + } else { + log.Info("Routing cleanup complete") + } } + m.ctx = nil } @@ -210,9 +219,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } func isPrefixSupported(prefix netip.Prefix) bool { - switch runtime.GOOS { - case "linux", "windows", "darwin": - return true + if !nbnet.CustomRoutingDisabled() { + switch runtime.GOOS { + case "linux", "windows", "darwin": + return true + } } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index dd00626e1..d1302b39c 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -4,14 +4,14 @@ package routemanager import ( "bufio" - "context" "errors" "fmt" "net" "net/netip" "os" + "strconv" + "strings" "syscall" - "time" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -32,19 +32,31 @@ const ( rtTablesPath = "/etc/iproute2/rt_tables" // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. - ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" + ipv4ForwardingPath = "net.ipv4.ip_forward" + + rpFilterPath = "net.ipv4.conf.all.rp_filter" + rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter" + srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark" ) var ErrTableIDExists = errors.New("ID exists with different name") var routeManager = &RouteManager{} -var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" + +// originalSysctl stores the original sysctl values before they are modified +var originalSysctl map[string]int + +// determines whether to use the legacy routing setup +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() + +// sysctlFailed is used as an indicator to emit a warning when default routes are configured +var sysctlFailed bool type ruleParams struct { + priority int fwmark int tableID int family int - priority int invert bool suppressPrefix int description string @@ -52,10 +64,10 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, + {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, + {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, + {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, + {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, } } @@ -69,8 +81,6 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -// -// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { if isLegacy { log.Infof("Using legacy routing setup") @@ -81,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before log.Errorf("Error adding routing table name: %v", err) } + originalValues, err := setupSysctl(wgIface) + if err != nil { + log.Errorf("Error setting up sysctl: %v", err) + sysctlFailed = true + } + originalSysctl = originalValues + defer func() { if err != nil { if cleanErr := cleanupRouting(); cleanErr != nil { @@ -123,11 +140,17 @@ func cleanupRouting() error { rules := getSetupRules() for _, rule := range rules { - if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { + if err := removeRule(rule); err != nil { result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) } } + if err := cleanupSysctl(originalSysctl); err != nil { + result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err)) + } + originalSysctl = nil + sysctlFailed = false + return result.ErrorOrNil() } @@ -144,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error { return genericAddVPNRoute(prefix, intf) } + if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) { + log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)") + } + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // TODO remove this once we have ipv6 support @@ -336,22 +363,8 @@ func flushRoutes(tableID, family int) error { } func enableIPForwarding() error { - bytes, err := os.ReadFile(ipv4ForwardingPath) - if err != nil { - return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) - } - - // check if it is already enabled - // see more: https://github.com/netbirdio/netbird/issues/872 - if len(bytes) > 0 && bytes[0] == 49 { - return nil - } - - //nolint:gosec - if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { - return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) - } - return nil + _, err := setSysctl(ipv4ForwardingPath, 1, false) + return err } // entryExists checks if the specified ID or name already exists in the rt_tables file @@ -429,7 +442,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("add routing rule: %w", err) } @@ -446,43 +459,13 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("remove routing rule: %w", err) } return nil } -func removeAllRules(params ruleParams) error { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - done := make(chan error, 1) - go func() { - for { - if ctx.Err() != nil { - done <- ctx.Err() - return - } - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { - done <- nil - return - } - done <- err - return - } - } - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - return err - } -} - // addNextHop adds the gateway and device to the route. func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { if addr.IsValid() { @@ -509,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int { } return netlink.FAMILY_V6 } + +// setupSysctl configures sysctl settings for RP filtering and source validation. +func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) { + keys := map[string]int{} + var result *multierror.Error + + oldVal, err := setSysctl(srcValidMarkPath, 1, false) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[srcValidMarkPath] = oldVal + } + + oldVal, err = setSysctl(rpFilterPath, 2, true) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[rpFilterPath] = oldVal + } + + interfaces, err := net.Interfaces() + if err != nil { + result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err)) + } + + for _, intf := range interfaces { + if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() { + continue + } + + i := fmt.Sprintf(rpFilterInterfacePath, intf.Name) + oldVal, err := setSysctl(i, 2, true) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[i] = oldVal + } + } + + return keys, result.ErrorOrNil() +} + +// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1 +func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) { + path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/")) + currentValue, err := os.ReadFile(path) + if err != nil { + return -1, fmt.Errorf("read sysctl %s: %w", key, err) + } + + currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue))) + if err != nil && len(currentValue) > 0 { + return -1, fmt.Errorf("convert current desiredValue to int: %w", err) + } + + if currentV == desiredValue || onlyIfOne && currentV != 1 { + return currentV, nil + } + + //nolint:gosec + if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil { + return currentV, fmt.Errorf("write sysctl %s: %w", key, err) + } + log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue) + + return currentV, nil +} + +func cleanupSysctl(originalSettings map[string]int) error { + var result *multierror.Error + + for key, value := range originalSettings { + _, err := setSysctl(key, value, false) + if err != nil { + result = multierror.Append(result, err) + } + } + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 97386f19a..9f906c06f 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) + _, _, err = setupRouting(nil, wgInterface) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 334ace453..ba211082f 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -73,7 +73,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx s } script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`, + `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`, psCmd, addressFamily, destinationPrefix, ) diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 2235c5d2b..22d327376 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { } // Set the fwmark on the socket. - err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) + err = nbnet.SetSocketOpt(fd) if err != nil { return nil, fmt.Errorf("setting fwmark failed: %w", err) } diff --git a/client/server/server.go b/client/server/server.go index d1d9dbda4..d33bb5155 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, - Routes: maps.Keys(peerState.Routes), + Routes: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) diff --git a/go.mod b/go.mod index bd56d7d71..48c696bba 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 62652dac5..53615e570 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 9fe987cee..67bfb716d 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,8 +10,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := nbnet.NetbirdFwmark + fwmark := getFwmark() config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 24dfadf14..c15bc1448 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } func getFwmark() int { - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() { return nbnet.NetbirdFwmark } return 0 diff --git a/management/cmd/management.go b/management/cmd/management.go index 23d9c195c..366935802 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -251,7 +251,7 @@ var ( ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg) + httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 806d09106..e6e77a58b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -243,19 +243,19 @@ type UserPermissions struct { } type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` - NonDeletable bool `json:"non_deletable"` - LastLogin time.Time `json:"last_login"` - Issued string `json:"issued"` + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + NonDeletable bool `json:"non_deletable"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` + Permissions UserPermissions `json:"permissions"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes @@ -1121,7 +1121,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") } - if user.Id != account.CreatedBy { + if user.Role != UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } for _, otherUser := range account.Users { @@ -1850,6 +1850,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut } func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { + log.Debugf("validated peers has been invalidated for account %s", accountID) updatedAccount, err := am.Store.GetAccount(accountID) if err != nil { log.Errorf("failed to get account %s: %v", accountID, err) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index e179fd14d..4ee57f181 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -11,133 +11,134 @@ type Code struct { Code string } +// Existing consts must not be changed, as this will break the compatibility with the existing data const ( // PeerAddedByUser indicates that a user added a new peer to the system - PeerAddedByUser Activity = iota + PeerAddedByUser Activity = 0 // PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key - PeerAddedWithSetupKey + PeerAddedWithSetupKey Activity = 1 // UserJoined indicates that a new user joined the account - UserJoined + UserJoined Activity = 2 // UserInvited indicates that a new user was invited to join the account - UserInvited + UserInvited Activity = 3 // AccountCreated indicates that a new account has been created - AccountCreated + AccountCreated Activity = 4 // PeerRemovedByUser indicates that a user removed a peer from the system - PeerRemovedByUser + PeerRemovedByUser Activity = 5 // RuleAdded indicates that a user added a new rule - RuleAdded + RuleAdded Activity = 6 // RuleUpdated indicates that a user updated a rule - RuleUpdated + RuleUpdated Activity = 7 // RuleRemoved indicates that a user removed a rule - RuleRemoved + RuleRemoved Activity = 8 // PolicyAdded indicates that a user added a new policy - PolicyAdded + PolicyAdded Activity = 9 // PolicyUpdated indicates that a user updated a policy - PolicyUpdated + PolicyUpdated Activity = 10 // PolicyRemoved indicates that a user removed a policy - PolicyRemoved + PolicyRemoved Activity = 11 // SetupKeyCreated indicates that a user created a new setup key - SetupKeyCreated + SetupKeyCreated Activity = 12 // SetupKeyUpdated indicates that a user updated a setup key - SetupKeyUpdated + SetupKeyUpdated Activity = 13 // SetupKeyRevoked indicates that a user revoked a setup key - SetupKeyRevoked + SetupKeyRevoked Activity = 14 // SetupKeyOverused indicates that setup key usage exhausted - SetupKeyOverused + SetupKeyOverused Activity = 15 // GroupCreated indicates that a user created a group - GroupCreated + GroupCreated Activity = 16 // GroupUpdated indicates that a user updated a group - GroupUpdated + GroupUpdated Activity = 17 // GroupAddedToPeer indicates that a user added group to a peer - GroupAddedToPeer + GroupAddedToPeer Activity = 18 // GroupRemovedFromPeer indicates that a user removed peer group - GroupRemovedFromPeer + GroupRemovedFromPeer Activity = 19 // GroupAddedToUser indicates that a user added group to a user - GroupAddedToUser + GroupAddedToUser Activity = 20 // GroupRemovedFromUser indicates that a user removed a group from a user - GroupRemovedFromUser + GroupRemovedFromUser Activity = 21 // UserRoleUpdated indicates that a user changed the role of a user - UserRoleUpdated + UserRoleUpdated Activity = 22 // GroupAddedToSetupKey indicates that a user added group to a setup key - GroupAddedToSetupKey + GroupAddedToSetupKey Activity = 23 // GroupRemovedFromSetupKey indicates that a user removed a group from a setup key - GroupRemovedFromSetupKey + GroupRemovedFromSetupKey Activity = 24 // GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups - GroupAddedToDisabledManagementGroups + GroupAddedToDisabledManagementGroups Activity = 25 // GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups - GroupRemovedFromDisabledManagementGroups + GroupRemovedFromDisabledManagementGroups Activity = 26 // RouteCreated indicates that a user created a route - RouteCreated + RouteCreated Activity = 27 // RouteRemoved indicates that a user deleted a route - RouteRemoved + RouteRemoved Activity = 28 // RouteUpdated indicates that a user updated a route - RouteUpdated + RouteUpdated Activity = 29 // PeerSSHEnabled indicates that a user enabled SSH server on a peer - PeerSSHEnabled + PeerSSHEnabled Activity = 30 // PeerSSHDisabled indicates that a user disabled SSH server on a peer - PeerSSHDisabled + PeerSSHDisabled Activity = 31 // PeerRenamed indicates that a user renamed a peer - PeerRenamed + PeerRenamed Activity = 32 // PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer - PeerLoginExpirationEnabled + PeerLoginExpirationEnabled Activity = 33 // PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer - PeerLoginExpirationDisabled + PeerLoginExpirationDisabled Activity = 34 // NameserverGroupCreated indicates that a user created a nameservers group - NameserverGroupCreated + NameserverGroupCreated Activity = 35 // NameserverGroupDeleted indicates that a user deleted a nameservers group - NameserverGroupDeleted + NameserverGroupDeleted Activity = 36 // NameserverGroupUpdated indicates that a user updated a nameservers group - NameserverGroupUpdated + NameserverGroupUpdated Activity = 37 // AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account - AccountPeerLoginExpirationEnabled + AccountPeerLoginExpirationEnabled Activity = 38 // AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account - AccountPeerLoginExpirationDisabled + AccountPeerLoginExpirationDisabled Activity = 39 // AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account - AccountPeerLoginExpirationDurationUpdated + AccountPeerLoginExpirationDurationUpdated Activity = 40 // PersonalAccessTokenCreated indicates that a user created a personal access token - PersonalAccessTokenCreated + PersonalAccessTokenCreated Activity = 41 // PersonalAccessTokenDeleted indicates that a user deleted a personal access token - PersonalAccessTokenDeleted + PersonalAccessTokenDeleted Activity = 42 // ServiceUserCreated indicates that a user created a service user - ServiceUserCreated + ServiceUserCreated Activity = 43 // ServiceUserDeleted indicates that a user deleted a service user - ServiceUserDeleted + ServiceUserDeleted Activity = 44 // UserBlocked indicates that a user blocked another user - UserBlocked + UserBlocked Activity = 45 // UserUnblocked indicates that a user unblocked another user - UserUnblocked + UserUnblocked Activity = 46 // UserDeleted indicates that a user deleted another user - UserDeleted + UserDeleted Activity = 47 // GroupDeleted indicates that a user deleted group - GroupDeleted + GroupDeleted Activity = 48 // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login - UserLoggedInPeer + UserLoggedInPeer Activity = 49 // PeerLoginExpired indicates that the user peer login has been expired and peer disconnected - PeerLoginExpired + PeerLoginExpired Activity = 50 // DashboardLogin indicates that the user logged in to the dashboard - DashboardLogin + DashboardLogin Activity = 51 // IntegrationCreated indicates that the user created an integration - IntegrationCreated + IntegrationCreated Activity = 52 // IntegrationUpdated indicates that the user updated an integration - IntegrationUpdated + IntegrationUpdated Activity = 53 // IntegrationDeleted indicates that the user deleted an integration - IntegrationDeleted + IntegrationDeleted Activity = 54 // AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account - AccountPeerApprovalEnabled + AccountPeerApprovalEnabled Activity = 55 // AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account - AccountPeerApprovalDisabled + AccountPeerApprovalDisabled Activity = 56 // PeerApproved indicates that the peer has been approved - PeerApproved + PeerApproved Activity = 57 // PeerApprovalRevoked indicates that the peer approval has been revoked - PeerApprovalRevoked + PeerApprovalRevoked Activity = 58 // TransferredOwnerRole indicates that the user transferred the owner role of the account - TransferredOwnerRole + TransferredOwnerRole Activity = 59 // PostureCheckCreated indicates that the user created a posture check - PostureCheckCreated + PostureCheckCreated Activity = 60 // PostureCheckUpdated indicates that the user updated a posture check - PostureCheckUpdated + PostureCheckUpdated Activity = 61 // PostureCheckDeleted indicates that the user deleted a posture check - PostureCheckDeleted + PostureCheckDeleted Activity = 62 ) var activityMap = map[Activity]Code{ diff --git a/management/server/http/handler.go b/management/server/http/handler.go index bdbeba346..4405d295c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,6 +12,7 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -38,7 +39,7 @@ type emptyObject struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa AuthCfg: authCfg, } - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } diff --git a/management/server/peer.go b/management/server/peer.go index 3beceeb11..1448e3011 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -3,7 +3,6 @@ package server import ( "fmt" "net" - "slices" "strings" "time" @@ -13,7 +12,6 @@ import ( "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) @@ -21,11 +19,6 @@ import ( type PeerSync struct { // WireGuardPubKey is a peers WireGuard public key WireGuardPubKey string - // Meta is the system information passed by peer, must be always present - Meta nbpeer.PeerSystemMeta - // UpdateAccountPeers indicate updating account peers, - // which occurs when the peer's metadata is updated - UpdateAccountPeers bool } // PeerLogin used as a data object between the gRPC API and AccountManager on Login request. @@ -558,20 +551,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } - peer, updated := updatePeerMeta(peer, sync.Meta, account) - if updated { - err = am.Store.SaveAccount(account) - if err != nil { - return nil, nil, err - } - - if sync.UpdateAccountPeers { - am.updateAccountPeers(account) - } - } - - requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) - if requiresApproval { + peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + if peerNotValid { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } @@ -582,11 +563,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network am.updateAccountPeers(account) } - approvedPeersMap, err := am.GetValidatedPeers(account) + validPeersMap, err := am.GetValidatedPeers(account) if err != nil { return nil, nil, err } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil } // LoginPeer logs in or registers a peer. @@ -885,65 +866,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) { } for _, peer := range peers { remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) - update := toSyncResponse(am, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) + update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) } } - -// GetPeerAppliedPostureChecks returns posture checks that are applied to the peer. -func (am *DefaultAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) { - account, err := am.Store.GetAccountByPeerPubKey(peerKey) - if err != nil { - log.Errorf("failed while getting peer %s: %v", peerKey, err) - return nil, err - } - - peer, err := account.FindPeerByPubKey(peerKey) - if err != nil { - return nil, status.Errorf(status.NotFound, "peer is not registered") - } - if peer == nil { - return nil, nil - } - - peerPostureChecks := make(map[string]posture.Checks) - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - outerLoop: - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, sourceGroup := range rule.Sources { - group, ok := account.Groups[sourceGroup] - if !ok { - continue - } - - // check if peer is in the rule source group - if slices.Contains(group.Peers, peer.ID) { - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - for _, postureChecks := range account.PostureChecks { - if postureChecks.ID == sourcePostureCheckID { - peerPostureChecks[sourcePostureCheckID] = *postureChecks - } - } - } - - break outerLoop - } - } - } - } - - postureChecksList := make([]posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - postureChecksList = append(postureChecksList, check) - } - - return postureChecksList, nil -} diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 96b2bc32b..63c56de17 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -3,6 +3,8 @@ package grpc import ( "context" "net" + "os/user" + "runtime" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -12,6 +14,20 @@ import ( func WithCustomDialer() grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + if runtime.GOOS == "linux" { + currentUser, err := user.Current() + if err != nil { + log.Fatalf("failed to get current user: %v", err) + } + + // the custom dialer requires root permissions which are not required for use cases run as non-root + if currentUser.Uid != "0" { + dialer := &net.Dialer{} + return dialer.DialContext(ctx, "tcp", addr) + } + } + + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 4eda710ac..1e217da13 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -49,6 +49,10 @@ func RemoveDialerHooks() { // DialContext wraps the net.Dialer's DialContext method to use the custom connection func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if CustomRoutingDisabled() { + return d.Dialer.DialContext(ctx, network, address) + } + var resolver *net.Resolver if d.Resolver != nil { resolver = d.Resolver @@ -123,6 +127,10 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r } func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + dialer := NewDialer() dialer.LocalAddr = laddr @@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { } func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + dialer := NewDialer() dialer.LocalAddr = laddr diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index 451279e9d..7847a29c7 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -8,6 +8,7 @@ import ( "net" "sync" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" ) @@ -52,6 +53,10 @@ func RemoveListenerHooks() { // ListenPacket listens on the network address and returns a PacketConn // which includes support for write hooks. func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) if err != nil { return nil, fmt.Errorf("listen packet: %w", err) @@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { // ListenUDP listens on the network address and returns a transport.UDPConn // which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) if err != nil { return nil, fmt.Errorf("listen UDP: %w", err) diff --git a/util/net/net.go b/util/net/net.go index 9ea7ae803..3856911b1 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,10 +1,16 @@ package net -import "github.com/google/uuid" +import ( + "os" + + "github.com/google/uuid" +) const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard NetbirdFwmark = 0x1BD00 + + envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) // ConnectionID provides a globally unique identifier for network connections. @@ -15,3 +21,7 @@ type ConnectionID string func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } + +func CustomRoutingDisabled() bool { + return os.Getenv(envDisableCustomRouting) == "true" +} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 821417500..954545eb5 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + setErr = SetSocketOpt(int(fd)) }) if err != nil { return fmt.Errorf("control: %w", err) @@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error { return nil } + +func SetSocketOpt(fd int) error { + if CustomRoutingDisabled() { + return nil + } + + return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) +}