From af50eb350f56827461171a35079b9197f83437da Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 25 Mar 2024 14:25:26 +0100 Subject: [PATCH 01/28] Change log level for JWT override message of single account mode (#1747) --- management/server/account.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/account.go b/management/server/account.go index 8b326d93a..8588cf343 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1588,7 +1588,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain claims.DomainCategory = PrivateCategory - log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + log.Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } newAcc, err := am.getAccountWithAuthorizationClaims(claims) From 68b377a28caf09e77613fa1b1a4d23bc6b3c7f03 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 26 Mar 2024 15:33:01 +0100 Subject: [PATCH 02/28] Collect chassis.serial (#1748) --- client/system/info_linux.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/system/info_linux.go b/client/system/info_linux.go index ca3be9d1c..652bc1115 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -120,5 +120,5 @@ func _getReleaseInfo() string { func sysInfo() (serialNumber string, productName string, manufacturer string) { var si sysinfo.SysInfo si.GetSysInfo() - return si.Product.Version, si.Product.Name, si.Product.Vendor + return si.Chassis.Serial, si.Product.Name, si.Product.Vendor } From ea2d060f93275695423bed1ab88cfd89e62d8c63 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 27 Mar 2024 16:11:45 +0100 Subject: [PATCH 03/28] Add limited dashboard view (#1738) --- management/server/account.go | 13 +- management/server/group.go | 38 ++++- management/server/http/accounts_handler.go | 2 + .../server/http/accounts_handler_test.go | 9 +- management/server/http/api/openapi.yml | 17 +- management/server/http/api/types.gen.go | 50 +++++- management/server/http/groups_handler.go | 22 ++- management/server/http/groups_handler_test.go | 8 +- management/server/http/users_handler.go | 3 + management/server/http/users_handler_test.go | 2 +- management/server/mock_server/account_mock.go | 27 ++- management/server/peer.go | 9 + management/server/peer_test.go | 155 +++++++++++++++++- management/server/setupkey.go | 8 + management/server/setupkey_test.go | 31 ++++ management/server/user.go | 36 +++- management/server/user_test.go | 77 +++++++++ 17 files changed, 466 insertions(+), 41 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 8588cf343..d9030007d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -85,7 +85,8 @@ type AccountManager interface { GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID string) (*Group, error) + GetGroup(accountId, groupID, userID string) (*Group, error) + GetAllGroups(accountID, userID string) ([]*Group, error) GetGroupByName(groupName, accountID string) (*Group, error) SaveGroup(accountID, userID string, group *Group) error DeleteGroup(accountId, userId, groupID string) error @@ -162,6 +163,9 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements + RegularUsersViewBlocked bool + // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer GroupsPropagationEnabled bool @@ -188,6 +192,7 @@ func (s *Settings) Copy() *Settings { JWTGroupsClaimName: s.JWTGroupsClaimName, GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, + RegularUsersViewBlocked: s.RegularUsersViewBlocked, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -226,6 +231,10 @@ type Account struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +type UserPermissions struct { + DashboardView string `json:"dashboard_view"` +} + type UserInfo struct { ID string `json:"id"` Email string `json:"email"` @@ -239,6 +248,7 @@ type UserInfo struct { LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` IntegrationReference IntegrationReference `json:"-"` + Permissions UserPermissions `json:"permissions"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes @@ -1885,6 +1895,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { PeerLoginExpirationEnabled: true, PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, }, } diff --git a/management/server/group.go b/management/server/group.go index 43d48e622..59f05a354 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -63,7 +63,7 @@ func (g *Group) Copy() *Group { } // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -72,6 +72,15 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er return nil, err } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + group, ok := account.Groups[groupID] if ok { return group, nil @@ -80,6 +89,33 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) } +// GetAllGroups returns all groups in an account +func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*Group, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + + groups := make([]*Group, 0, len(account.Groups)) + for _, item := range account.Groups { + groups = append(groups, item) + } + + return groups, nil +} + // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*Group, error) { unlock := am.Store.AcquireAccountLock(accountID) diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 71088cfaf..d3c9954d3 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -76,6 +76,7 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings := &server.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), + RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, } if req.Settings.Extra != nil { @@ -143,6 +144,7 @@ func toAccountResponse(account *server.Account) *api.Account { JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, + RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, } if account.Settings.Extra != nil { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index fd2c4bfcd..9d174d0be 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -69,6 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { Settings: &server.Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour, + RegularUsersViewBlocked: true, }, }, adminUser) @@ -96,6 +97,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, }, expectedArray: true, expectedID: accountID, @@ -114,6 +116,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, }, expectedArray: false, expectedID: accountID, @@ -123,7 +126,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -132,6 +135,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr("roles"), JwtGroupsEnabled: br(true), JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, }, expectedArray: false, expectedID: accountID, @@ -141,7 +145,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 554400, @@ -150,6 +154,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr("groups"), JwtGroupsEnabled: br(true), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 7ec2310af..281089396 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,10 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + regular_users_view_blocked: + description: Allows blocking regular users from viewing parts of the system. + type: boolean + example: true groups_propagation_enabled: description: Allows propagate the new user auto groups to peers that belongs to the user type: boolean @@ -77,6 +81,7 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - regular_users_view_blocked AccountExtraSettings: type: object properties: @@ -144,6 +149,8 @@ components: description: How user was issued by API or Integration type: string example: api + permissions: + $ref: '#/components/schemas/UserPermissions' required: - id - email @@ -152,6 +159,14 @@ components: - auto_groups - status - is_blocked + UserPermissions: + type: object + properties: + dashboard_view: + description: User's permission to view the dashboard + type: string + enum: [ "limited", "blocked", "full" ] + example: limited UserRequest: type: object properties: @@ -589,8 +604,6 @@ components: type: string enum: ["api", "integration", "jwt"] example: api - type: string - example: api required: - id - name diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index a4c492bb8..78cd83a27 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -69,6 +69,20 @@ const ( GeoLocationCheckActionDeny GeoLocationCheckAction = "deny" ) +// Defines values for GroupIssued. +const ( + GroupIssuedApi GroupIssued = "api" + GroupIssuedIntegration GroupIssued = "integration" + GroupIssuedJwt GroupIssued = "jwt" +) + +// Defines values for GroupMinimumIssued. +const ( + GroupMinimumIssuedApi GroupMinimumIssued = "api" + GroupMinimumIssuedIntegration GroupMinimumIssued = "integration" + GroupMinimumIssuedJwt GroupMinimumIssued = "jwt" +) + // Defines values for NameserverNsType. const ( NameserverNsTypeUdp NameserverNsType = "udp" @@ -129,6 +143,13 @@ const ( UserStatusInvited UserStatus = "invited" ) +// Defines values for UserPermissionsDashboardView. +const ( + UserPermissionsDashboardViewBlocked UserPermissionsDashboardView = "blocked" + UserPermissionsDashboardViewFull UserPermissionsDashboardView = "full" + UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited" +) + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud @@ -186,6 +207,9 @@ type AccountSettings struct { // PeerLoginExpirationEnabled Enables or disables peer login expiration globally. After peer's login has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). PeerLoginExpirationEnabled bool `json:"peer_login_expiration_enabled"` + + // RegularUsersViewBlocked Allows blocking regular users from viewing parts of the system. + RegularUsersViewBlocked bool `json:"regular_users_view_blocked"` } // Checks List of objects that perform the actual checks @@ -283,8 +307,8 @@ type Group struct { // Id Group ID Id string `json:"id"` - // Issued How group was issued by API or from JWT token - Issued *string `json:"issued,omitempty"` + // Issued How the group was issued (api, integration, jwt) + Issued *GroupIssued `json:"issued,omitempty"` // Name Group Name identifier Name string `json:"name"` @@ -296,13 +320,16 @@ type Group struct { PeersCount int `json:"peers_count"` } +// GroupIssued How the group was issued (api, integration, jwt) +type GroupIssued string + // GroupMinimum defines model for GroupMinimum. type GroupMinimum struct { // Id Group ID Id string `json:"id"` - // Issued How group was issued by API or from JWT token - Issued *string `json:"issued,omitempty"` + // Issued How the group was issued (api, integration, jwt) + Issued *GroupMinimumIssued `json:"issued,omitempty"` // Name Group Name identifier Name string `json:"name"` @@ -311,6 +338,9 @@ type GroupMinimum struct { PeersCount int `json:"peers_count"` } +// GroupMinimumIssued How the group was issued (api, integration, jwt) +type GroupMinimumIssued string + // GroupRequest defines model for GroupRequest. type GroupRequest struct { // Name Group name identifier @@ -1072,7 +1102,8 @@ type User struct { LastLogin *time.Time `json:"last_login,omitempty"` // Name User's name from idp provider - Name string `json:"name"` + Name string `json:"name"` + Permissions *UserPermissions `json:"permissions,omitempty"` // Role User's NetBird account role Role string `json:"role"` @@ -1102,6 +1133,15 @@ type UserCreateRequest struct { Role string `json:"role"` } +// UserPermissions defines model for UserPermissions. +type UserPermissions struct { + // DashboardView User's permission to view the dashboard + DashboardView *UserPermissionsDashboardView `json:"dashboard_view,omitempty"` +} + +// UserPermissionsDashboardView User's permission to view the dashboard +type UserPermissionsDashboardView string + // UserRequest defines model for UserRequest. type UserRequest struct { // AutoGroups Group IDs to auto-assign to peers registered by this user diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index b37f4fd2f..56d06595f 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -35,19 +35,25 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - var groups []*api.Group - for _, g := range account.Groups { - groups = append(groups, toGroupResponse(account, g)) + groups, err := h.accountManager.GetAllGroups(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return } - util.WriteJSONObject(w, groups) + groupsResponse := make([]*api.Group, 0, len(groups)) + for _, group := range groups { + groupsResponse = append(groupsResponse, toGroupResponse(account, group)) + } + + util.WriteJSONObject(w, groupsResponse) } // UpdateGroup handles update to a group identified by a given ID @@ -207,7 +213,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) return @@ -221,7 +227,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := h.accountManager.GetGroup(account.Id, groupID) + group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id) if err != nil { util.WriteError(err, w) return @@ -239,7 +245,7 @@ func toGroupResponse(account *server.Account, group *server.Group) *api.Group { gr := api.Group{ Id: group.ID, Name: group.Name, - Issued: &group.Issued, + Issued: (*api.GroupIssued)(&group.Issued), } for _, pid := range group.Peers { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 5b47b1208..303efc9d7 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -37,7 +37,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle } return nil }, - GetGroupFunc: func(_, groupID string) (*server.Group, error) { + GetGroupFunc: func(_, groupID, _ string) (*server.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } @@ -187,7 +187,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-was-set", Name: "Default POSTed Group", - Issued: &groupIssuedAPI, + Issued: (*api.GroupIssued)(&groupIssuedAPI), }, }, { @@ -209,7 +209,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-existed", Name: "Default POSTed Group", - Issued: &groupIssuedAPI, + Issued: (*api.GroupIssued)(&groupIssuedAPI), }, }, { @@ -240,7 +240,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-jwt-group", Name: "changed", - Issued: &groupIssuedJWT, + Issued: (*api.GroupIssued)(&groupIssuedJWT), }, }, } diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 5d92b65e5..ed8a3f543 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -288,5 +288,8 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { IsBlocked: user.IsBlocked, LastLogin: &user.LastLogin, Issued: &user.Issued, + Permissions: &api.UserPermissions{ + DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView), + }, } } diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index ff886ca9f..91f19d8d8 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -105,7 +105,7 @@ func initUsersTestData() *UsersHandler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil) + info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false}) if err != nil { return nil, err } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index f518372ed..9463498cf 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -31,7 +31,8 @@ type MockAccountManager struct { GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID string) (*server.Group, error) + GetGroupFunc func(accountID, groupID, userID string) (*server.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*server.Group, error) GetGroupByNameFunc func(accountID, groupName string) (*server.Group, error) SaveGroupFunc func(accountID, userID string, group *server.Group) error DeleteGroupFunc func(accountID, userId, groupID string) error @@ -92,6 +93,22 @@ type MockAccountManager struct { GetIdpManagerFunc func() idp.Manager } +// GetGroup mock implementation of GetGroup from server.AccountManager interface +func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*server.Group, error) { + if am.GetGroupFunc != nil { + return am.GetGroupFunc(accountId, groupID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") +} + +// GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface +func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*server.Group, error) { + if am.GetAllGroupsFunc != nil { + return am.GetAllGroupsFunc(accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllGroups is not implemented") +} + // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { @@ -243,14 +260,6 @@ func (am *MockAccountManager) AddPeer( return nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } -// GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountID, groupID string) (*server.Group, error) { - if am.GetGroupFunc != nil { - return am.GetGroupFunc(accountID, groupID) - } - return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") -} - // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*server.Group, error) { if am.GetGroupFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index 53b86e9b3..7de1b6542 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -54,6 +54,11 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return peers, nil + } + for _, peer := range account.Peers { if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin @@ -738,6 +743,10 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) + } + peer := account.GetPeer(peerID) if peer == nil { return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index ee84ea47d..7f6d440bb 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -4,9 +4,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/rs/xid" + "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -392,6 +391,8 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { Id: someUser, Role: UserRoleUser, } + account.Settings.RegularUsersViewBlocked = false + err = manager.Store.SaveAccount(account) if err != nil { t.Fatal(err) @@ -480,3 +481,153 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } assert.NotNil(t, peer) } + +func TestDefaultAccountManager_GetPeers(t *testing.T) { + testCases := []struct { + name string + role UserRole + limitedViewSettings bool + isServiceUser bool + expectedPeerCount int + }{ + { + name: "Regular user, no limited view settings, not a service user", + role: UserRoleUser, + limitedViewSettings: false, + isServiceUser: false, + expectedPeerCount: 1, + }, + { + name: "Service user, no limited view settings", + role: UserRoleUser, + limitedViewSettings: false, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Regular user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 0, + }, + { + name: "Service user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Admin, no limited view settings, not a service user", + role: UserRoleAdmin, + limitedViewSettings: false, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Admin service user, no limited view settings", + role: UserRoleAdmin, + limitedViewSettings: false, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Admin, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Admin Service user, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Owner, no limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Owner, limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + // account with an admin and a regular user + accountID := "test_account" + adminUser := "account_creator" + someUser := "some_user" + account := newAccountWithId(accountID, adminUser, "") + account.Users[someUser] = &User{ + Id: someUser, + Role: testCase.role, + IsServiceUser: testCase.isServiceUser, + } + account.Policies = []*Policy{} + account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + + err = manager.Store.SaveAccount(account) + if err != nil { + t.Fatal(err) + return + } + + peerKey1, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + peerKey2, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{ + Key: peerKey1.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, + }) + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{ + Key: peerKey2.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, + }) + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + peers, err := manager.GetPeers(accountID, someUser) + if err != nil { + t.Fatal(err) + return + } + assert.NotNil(t, peers) + + assert.Len(t, peers, testCase.expectedPeerCount) + + }) + } + +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 972665527..ff6fb3204 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -339,6 +339,10 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + } + keys := make([]*SetupKey, 0, len(account.SetupKeys)) for _, key := range account.SetupKeys { var k *SetupKey @@ -368,6 +372,10 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (* return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + } + var foundKey *SetupKey for _, key := range account.SetupKeys { if key.Id == keyID { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index c22df2094..b714652f1 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -166,6 +166,37 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } +func TestGetSetupKeys(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(userID, "") + if err != nil { + t.Fatal(err) + } + + err = manager.SaveGroup(account.Id, userID, &Group{ + ID: "group_1", + Name: "group_name_1", + Peers: []string{}, + }) + if err != nil { + t.Fatal(err) + } + + err = manager.SaveGroup(account.Id, userID, &Group{ + ID: "group_2", + Name: "group_name_2", + Peers: []string{}, + }) + if err != nil { + t.Fatal(err) + } +} + func TestGenerateDefaultSetupKey(t *testing.T) { expectedName := "Default key" expectedRevoke := false diff --git a/management/server/user.go b/management/server/user.go index f1516139b..15517db41 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -113,12 +113,20 @@ func (u *User) HasAdminPower() bool { } // ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { +func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups if autoGroups == nil { autoGroups = []string{} } + dashboardViewPermissions := "full" + if !u.HasAdminPower() { + dashboardViewPermissions = "limited" + if settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + if userData == nil { return &UserInfo{ ID: u.Id, @@ -131,6 +139,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsBlocked: u.Blocked, LastLogin: u.LastLogin, Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, }, nil } if userData.ID != u.Id { @@ -153,6 +164,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsBlocked: u.Blocked, LastLogin: u.LastLogin, Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, }, nil } @@ -358,7 +372,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite am.StoreEvent(userID, newUser.Id, accountID, activity.UserInvited, nil) - return newUser.ToUserInfo(idpUser) + return newUser.ToUserInfo(idpUser, account.Settings) } // GetUser looks up a user by provided authorization claims. @@ -905,9 +919,9 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string if err != nil { return nil, err } - return newUser.ToUserInfo(userData) + return newUser.ToUserInfo(userData, account.Settings) } - return newUser.ToUserInfo(nil) + return newUser.ToUserInfo(nil, account.Settings) } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist @@ -998,7 +1012,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( // if user is not an admin then show only current user and do not show other users continue } - info, err := accountUser.ToUserInfo(nil) + info, err := accountUser.ToUserInfo(nil, account.Settings) if err != nil { return nil, err } @@ -1015,7 +1029,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( var info *UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.ToUserInfo(queriedUser) + info, err = localUser.ToUserInfo(queriedUser, account.Settings) if err != nil { return nil, err } @@ -1024,6 +1038,15 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( if localUser.IsServiceUser { name = localUser.ServiceUserName } + + dashboardViewPermissions := "full" + if !localUser.HasAdminPower() { + dashboardViewPermissions = "limited" + if account.Settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + info = &UserInfo{ ID: localUser.Id, Email: "", @@ -1033,6 +1056,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( Status: string(UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, + Permissions: UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfos = append(userInfos, info) diff --git a/management/server/user_test.go b/management/server/user_test.go index 50cd726ef..e34aa406d 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -709,6 +709,83 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { assert.Equal(t, 2, regular) } +func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { + testCases := []struct { + name string + role UserRole + limitedViewSettings bool + expectedDashboardPermissions string + }{ + { + name: "Regular user, no limited view settings", + role: UserRoleUser, + limitedViewSettings: false, + expectedDashboardPermissions: "limited", + }, + { + name: "Admin user, no limited view settings", + role: UserRoleAdmin, + limitedViewSettings: false, + expectedDashboardPermissions: "full", + }, + { + name: "Owner, no limited view settings", + role: UserRoleOwner, + limitedViewSettings: false, + expectedDashboardPermissions: "full", + }, + { + name: "Regular user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + expectedDashboardPermissions: "blocked", + }, + { + name: "Admin user, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + expectedDashboardPermissions: "full", + }, + { + name: "Owner, limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + expectedDashboardPermissions: "full", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + delete(account.Users, mockUserID) + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + users, err := am.ListUsers(mockAccountID) + if err != nil { + t.Fatalf("Error when checking user role: %s", err) + } + + assert.Equal(t, 1, len(users)) + + userInfo, _ := users[0].ToUserInfo(nil, account.Settings) + assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) + }) + } + +} + func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") From 2d76b058fcee1bebc54745a91e2b447da2a2c9db Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 27 Mar 2024 18:48:48 +0100 Subject: [PATCH 04/28] Feature/peer validator (#1553) Follow up management-integrations changes move groups to separated packages to avoid circle dependencies save location information in Login action --- client/cmd/testutil.go | 4 +- client/internal/engine_test.go | 4 +- client/server/server_test.go | 4 +- go.mod | 4 +- go.sum | 7 +- management/client/client_test.go | 11 +- management/cmd/management.go | 8 +- management/server/account.go | 88 +++++++---- management/server/account/account.go | 8 +- management/server/account_test.go | 103 +++++++++---- management/server/dns_test.go | 7 +- management/server/ephemeral.go | 2 +- management/server/file_store.go | 3 +- management/server/file_store_test.go | 8 +- management/server/group.go | 77 +++------- management/server/group/group.go | 46 ++++++ management/server/group_test.go | 35 ++--- management/server/grpcserver.go | 1 + management/server/http/api/openapi.yml | 1 + management/server/http/api/types.gen.go | 6 +- management/server/http/groups_handler.go | 22 +-- management/server/http/groups_handler_test.go | 27 ++-- management/server/http/handler.go | 1 - management/server/http/peers_handler.go | 141 ++++++++++++------ .../server/http/policies_handler_test.go | 3 +- .../server/http/setupkeys_handler_test.go | 11 +- management/server/http/util/util.go | 2 + management/server/integrated_validator.go | 80 ++++++++++ .../server/integrated_validator/interface.go | 19 +++ .../integration_reference.go | 23 +++ management/server/management_proto_test.go | 5 +- management/server/management_test.go | 57 +++++-- management/server/metrics/selfhosted_test.go | 5 +- management/server/mock_server/account_mock.go | 47 ++++-- management/server/nameserver.go | 3 +- management/server/nameserver_test.go | 7 +- management/server/peer.go | 121 ++++++++++++--- management/server/peer_test.go | 5 +- management/server/policy.go | 19 ++- management/server/policy_test.go | 51 ++++--- management/server/route_test.go | 9 +- management/server/setupkey_test.go | 11 +- management/server/sqlite_store.go | 15 +- management/server/user.go | 20 +-- management/server/user_test.go | 10 +- 45 files changed, 790 insertions(+), 351 deletions(-) create mode 100644 management/server/group/group.go create mode 100644 management/server/integrated_validator.go create mode 100644 management/server/integrated_validator/interface.go create mode 100644 management/server/integration_reference/integration_reference.go diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 2cfc93415..2f92e1c03 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc" + "github.com/netbirdio/management-integrations/integrations" clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" mgmtProto "github.com/netbirdio/netbird/management/proto" @@ -78,7 +79,8 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { return nil, nil } - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + iv, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 952b3c90c..309b2e7c6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -1050,7 +1051,8 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 7f8310c90..4e4a09145 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "github.com/netbirdio/management-integrations/integrations" "net" "testing" "time" @@ -114,7 +115,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 67ec9c42e..5566f8559 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.9 github.com/google/gopacket v1.1.19 + github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 @@ -59,8 +60,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index c36b8aff3..6da405341 100644 --- a/go.sum +++ b/go.sum @@ -255,6 +255,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= @@ -382,10 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= diff --git a/management/client/client_test.go b/management/client/client_test.go index f30ae0cfd..30f91c73b 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -3,6 +3,7 @@ package client import ( "context" "net" + "os" "path/filepath" "sync" "testing" @@ -15,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" mgmt "github.com/netbirdio/netbird/management/server" @@ -30,6 +32,12 @@ import ( const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" +func TestMain(m *testing.M) { + _ = util.InitLog("debug", "console") + code := m.Run() + os.Exit(code) +} + func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Helper() level, _ := log.ParseLevel("debug") @@ -60,7 +68,8 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index e8bcdc97d..23d9c195c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -172,8 +173,12 @@ var ( log.Infof("geo location service has been initialized from %s", config.Datadir) } + integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore) + if err != nil { + return fmt.Errorf("failed to initialize integrated peer validator: %v", err) + } accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore, geo, userDeleteFromIDPEnabled) + dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } @@ -323,6 +328,7 @@ var ( SetupCloseHandler() <-stopCh + integratedPeerValidator.Stop() if geo != nil { _ = geo.Stop() } diff --git a/management/server/account.go b/management/server/account.go index d9030007d..c145c1bd7 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -21,14 +21,15 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/management-integrations/additions" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integrated_validator" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -85,12 +86,12 @@ type AccountManager interface { GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID, userID string) (*Group, error) - GetAllGroups(accountID, userID string) ([]*Group, error) - GetGroupByName(groupName, accountID string) (*Group, error) - SaveGroup(accountID, userID string, group *Group) error + GetGroup(accountId, groupID, userID string) (*nbgroup.Group, error) + GetAllGroups(accountID, userID string) ([]*nbgroup.Group, error) + GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) + SaveGroup(accountID, userID string, group *nbgroup.Group) error DeleteGroup(accountId, userId, groupID string) error - ListGroups(accountId string) ([]*Group, error) + ListGroups(accountId string) ([]*nbgroup.Group, error) GroupAddPeer(accountId, groupID, peerID string) error GroupDeletePeer(accountId, groupID, peerID string) error GetPolicy(accountID, policyID, userID string) (*Policy, error) @@ -124,6 +125,9 @@ type AccountManager interface { DeletePostureChecks(accountID, postureChecksID, userID string) error ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager + UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error + GroupValidation(accountId string, groups []string) (bool, error) + GetValidatedPeers(account *Account) (map[string]struct{}, error) } type DefaultAccountManager struct { @@ -152,6 +156,8 @@ type DefaultAccountManager struct { // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool + + integratedPeerValidator integrated_validator.IntegratedValidator } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -218,8 +224,8 @@ type Account struct { PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*nbgroup.Group `gorm:"-"` + GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[string]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` @@ -247,7 +253,7 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` - IntegrationReference IntegrationReference `json:"-"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` Permissions UserPermissions `json:"permissions"` } @@ -372,25 +378,26 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { } // GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *Group { +func (a *Account) GetGroup(groupID string) *nbgroup.Group { return a.Groups[groupID] } // GetPeerNetworkMap returns a group by ID if exists, nil otherwise -func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { +func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ Network: a.Network.Copy(), } } - validatedPeers := additions.ValidatePeers([]*nbpeer.Peer{peer}) - if len(validatedPeers) == 0 { + + if _, ok := validatedPeersMap[peerID]; !ok { return &NetworkMap{ Network: a.Network.Copy(), } } - aclPeers, firewallRules := a.getPeerConnectionResources(peerID) + + aclPeers, firewallRules := a.getPeerConnectionResources(peerID, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -564,7 +571,7 @@ func (a *Account) FindUser(userID string) (*User, error) { } // FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. -func (a *Account) FindGroupByName(groupName string) (*Group, error) { +func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { for _, group := range a.Groups { if group.Name == groupName { return group, nil @@ -583,6 +590,20 @@ func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { return key, nil } +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + func (a *Account) getUserGroups(userID string) ([]string, error) { user, err := a.FindUser(userID) if err != nil { @@ -660,7 +681,7 @@ func (a *Account) Copy() *Account { setupKeys[id] = key.Copy() } - groups := map[string]*Group{} + groups := map[string]*nbgroup.Group{} for id, group := range a.Groups { groups[id] = group.Copy() } @@ -713,7 +734,7 @@ func (a *Account) Copy() *Account { } } -func (a *Account) GetGroupAll() (*Group, error) { +func (a *Account) GetGroupAll() (*nbgroup.Group, error) { for _, g := range a.Groups { if g.Name == "All" { return g, nil @@ -734,7 +755,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { return false } - existedGroupsByName := make(map[string]*Group) + existedGroupsByName := make(map[string]*nbgroup.Group) for _, group := range a.Groups { existedGroupsByName[group.Name] = group } @@ -743,7 +764,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { removed := 0 jwtAutoGroups := make(map[string]struct{}) for i, id := range user.AutoGroups { - if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT { + if group, ok := a.Groups[id]; ok && group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = struct{}{} user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) removed++ @@ -756,15 +777,15 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { for _, name := range groupsNames { group, ok := existedGroupsByName[name] if !ok { - group = &Group{ + group = &nbgroup.Group{ ID: xid.New().String(), Name: name, - Issued: GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, } a.Groups[group.ID] = group } // only JWT groups will be synced - if group.Issued == GroupIssuedJWT { + if group.Issued == nbgroup.GroupIssuedJWT { user.AutoGroups = append(user.AutoGroups, group.ID) if _, ok := jwtAutoGroups[name]; !ok { modified = true @@ -837,6 +858,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, + integratedPeerValidator integrated_validator.IntegratedValidator, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, @@ -850,6 +872,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, + integratedPeerValidator: integratedPeerValidator, } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -906,6 +929,8 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage }() } + am.integratedPeerValidator.SetPeerInvalidationListener(am.onPeersInvalidated) + return am, nil } @@ -948,7 +973,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } - err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore) + err = am.integratedPeerValidator.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) if err != nil { return nil, err } @@ -1823,18 +1848,27 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut return nil } +func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { + updatedAccount, err := am.Store.GetAccount(accountID) + if err != nil { + log.Errorf("failed to get account %s: %v", accountID, err) + return + } + am.updateAccountPeers(updatedAccount) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { - allGroup := &Group{ + allGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "All", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*Group{allGroup.ID: allGroup} + account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} id := xid.New().String() diff --git a/management/server/account/account.go b/management/server/account/account.go index b8b71a6de..40f032fbe 100644 --- a/management/server/account/account.go +++ b/management/server/account/account.go @@ -3,11 +3,17 @@ package account type ExtraSettings struct { // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator PeerApprovalEnabled bool + + // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations + IntegratedValidatorGroups []string `gorm:"serializer:json"` } // Copy copies the ExtraSettings struct func (e *ExtraSettings) Copy() *ExtraSettings { + var cpGroup []string + return &ExtraSettings{ - PeerApprovalEnabled: e.PeerApprovalEnabled, + PeerApprovalEnabled: e.PeerApprovalEnabled, + IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...), } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 2b0c44196..a0eff239b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -12,20 +12,57 @@ import ( "time" "github.com/golang-jwt/jwt" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/route" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" ) +type MocIntegratedValidator struct { +} + +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() { +} + func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { t.Helper() peer := &nbpeer.Peer{ @@ -367,7 +404,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { account.Groups[all.ID].Peers = append(account.Groups[all.ID].Peers, peer.ID) } - networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io") + validatedPeers := map[string]struct{}{} + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -667,7 +709,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*Group{} + groupsByNames := map[string]*group.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -675,12 +717,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") }) } @@ -800,7 +842,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) } @@ -815,7 +857,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to get an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("updating domain. expected %s got %s", domain, account.Domain) } } @@ -835,13 +877,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { } if account == nil { t.Fatalf("expected to create an account for a user %s", userId) + return } - accountId := account.Id - - _, err = manager.GetAccountByUserOrAccountID("", accountId, "") + _, err = manager.GetAccountByUserOrAccountID("", account.Id, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) } _, err = manager.GetAccountByUserOrAccountID("", "", "") @@ -1124,7 +1165,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) defer manager.peersUpdateManager.CloseChannel(peer1.ID) - group := Group{ + group := group.Group{ ID: "group-id", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1417,7 +1458,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[string]*route.Route{ "route-1": { ID: "route-1", @@ -1518,7 +1559,7 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*group.Group{ "group1": { ID: "group1", Peers: []string{"peer1"}, @@ -2112,8 +2153,8 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Settings: &Settings{GroupsPropagationEnabled: true}, Users: map[string]*User{ @@ -2160,10 +2201,10 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2196,10 +2237,10 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2223,7 +2264,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index aac35308c..18f942e68 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -193,7 +194,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) } func createDNSStore(t *testing.T) (Store, error) { @@ -278,13 +279,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &Group{ + newGroup1 := &group.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &Group{ + newGroup2 := &group.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 9d70a05d1..4fffa024d 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -165,7 +165,7 @@ func (e *EphemeralManager) cleanup() { log.Debugf("delete ephemeral peer: %s", id) err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) if err != nil { - log.Tracef("failed to delete ephemeral peer: %s", err) + log.Errorf("failed to delete ephemeral peer: %s", err) } } } diff --git a/management/server/file_store.go b/management/server/file_store.go index 0228285cb..2de852bee 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,6 +10,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" @@ -170,7 +171,7 @@ func restore(file string) (*FileStore, error) { // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI } } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index d8575a3bf..d53298d8f 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -188,7 +189,7 @@ func TestStore(t *testing.T) { Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - account.Groups["all"] = &Group{ + account.Groups["all"] = &group.Group{ ID: "all", Name: "all", Peers: []string{"testpeer"}, @@ -320,7 +321,7 @@ func TestRestoreGroups_Migration(t *testing.T) { // create default group account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*Group{ + account.Groups = map[string]*group.Group{ "cfefqs706sqkneg59g3g": { ID: "cfefqs706sqkneg59g3g", Name: "All", @@ -336,7 +337,7 @@ func TestRestoreGroups_Migration(t *testing.T) { account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") + require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") } func TestGetAccountByPrivateDomain(t *testing.T) { @@ -384,6 +385,7 @@ func TestFileStore_GetAccount(t *testing.T) { expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] if expected == nil { t.Fatalf("expected account doesn't exist") + return } account, err := store.GetAccount(expected.Id) diff --git a/management/server/group.go b/management/server/group.go index 59f05a354..0fc952cdb 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -19,51 +20,8 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -const ( - GroupIssuedAPI = "api" - GroupIssuedJWT = "jwt" - GroupIssuedIntegration = "integration" -) - -// Group of the peers for ACL -type Group struct { - // ID of the group - ID string - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name visible in the UI - Name string - - // Issued defines how this group was created (enum of "api", "integration" or "jwt") - Issued string - - // Peers list of the group - Peers []string `gorm:"serializer:json"` - - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// EventMeta returns activity event meta related to the group -func (g *Group) EventMeta() map[string]any { - return map[string]any{"name": g.Name} -} - -func (g *Group) Copy() *Group { - group := &Group{ - ID: g.ID, - Name: g.Name, - Issued: g.Issued, - Peers: make([]string, len(g.Peers)), - IntegrationReference: g.IntegrationReference, - } - copy(group.Peers, g.Peers) - return group -} - // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -90,7 +48,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*G } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*Group, error) { +func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -108,7 +66,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") } - groups := make([]*Group, 0, len(account.Groups)) + groups := make([]*nbgroup.Group, 0, len(account.Groups)) for _, item := range account.Groups { groups = append(groups, item) } @@ -117,7 +75,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -126,7 +84,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G return nil, err } - matchingGroups := make([]*Group, 0) + matchingGroups := make([]*nbgroup.Group, 0) for _, group := range account.Groups { if group.Name == groupName { matchingGroups = append(matchingGroups, group) @@ -138,7 +96,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } maxPeers := -1 - var groupWithMostPeers *Group + var groupWithMostPeers *nbgroup.Group for i, group := range matchingGroups { if len(group.Peers) > maxPeers { maxPeers = len(group.Peers) @@ -150,7 +108,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error { +func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -159,11 +117,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G return err } - if newGroup.ID == "" && newGroup.Issued != GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { existingGroup, err := account.FindGroupByName(newGroup.Name) if err != nil { @@ -270,7 +228,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // disable a deleting integration group if the initiator is not an admin service user - if g.Issued == GroupIssuedIntegration { + if g.Issued == nbgroup.GroupIssuedIntegration { executingUser := account.Users[userId] if executingUser == nil { return status.Errorf(status.NotFound, "user not found") @@ -340,6 +298,15 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } } + // check integrated peer validator groups + if account.Settings.Extra != nil { + for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups { + if groupID == integratedPeerValidatorGroups { + return &GroupLinkError{"integrated validator", g.Name} + } + } + } + delete(account.Groups, groupID) account.Network.IncSerial() @@ -355,7 +322,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { +func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -364,7 +331,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) return nil, err } - groups := make([]*Group, 0, len(account.Groups)) + groups := make([]*nbgroup.Group, 0, len(account.Groups)) for _, item := range account.Groups { groups = append(groups, item) } diff --git a/management/server/group/group.go b/management/server/group/group.go new file mode 100644 index 000000000..79dfd995c --- /dev/null +++ b/management/server/group/group.go @@ -0,0 +1,46 @@ +package group + +import "github.com/netbirdio/netbird/management/server/integration_reference" + +const ( + GroupIssuedAPI = "api" + GroupIssuedJWT = "jwt" + GroupIssuedIntegration = "integration" +) + +// Group of the peers for ACL +type Group struct { + // ID of the group + ID string + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name visible in the UI + Name string + + // Issued defines how this group was created (enum of "api", "integration" or "jwt") + Issued string + + // Peers list of the group + Peers []string `gorm:"serializer:json"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// EventMeta returns activity event meta related to the group +func (g *Group) EventMeta() map[string]any { + return map[string]any{"name": g.Name} +} + +func (g *Group) Copy() *Group { + group := &Group{ + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: make([]string, len(g.Peers)), + IntegrationReference: g.IntegrationReference, + } + copy(group.Peers, g.Peers) + return group +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 3a2195c88..35e9b2170 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -5,6 +5,7 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" ) @@ -24,22 +25,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = GroupIssuedIntegration + group.Issued = nbgroup.GroupIssuedIntegration err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = GroupIssuedJWT + group.Issued = nbgroup.GroupIssuedJWT err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedJWT) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI group.ID = "" err = am.SaveGroup(account.Id, groupAdminUserID, group) if err == nil { @@ -129,51 +130,51 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &Group{ + groupForRoute := &nbgroup.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &Group{ + groupForNameServerGroups := &nbgroup.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &Group{ + groupForPolicies := &nbgroup.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &Group{ + groupForSetupKeys := &nbgroup.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &Group{ + groupForUsers := &nbgroup.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &Group{ + groupForIntegration := &nbgroup.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: GroupIssuedIntegration, + Issued: nbgroup.GroupIssuedIntegration, Peers: make([]string, 0), } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 341d202b6..340adcfc6 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -361,6 +361,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p Meta: extractPeerMeta(loginReq), UserID: userID, SetupKey: loginReq.GetSetupKey(), + ConnectionIP: realIP, }) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 281089396..f8f581bd4 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -355,6 +355,7 @@ components: - user_id - version - ui_version + - approval_required AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 78cd83a27..0bed93b3c 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -470,7 +470,7 @@ type Peer struct { AccessiblePeers []AccessiblePeer `json:"accessible_peers"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -539,7 +539,7 @@ type Peer struct { // PeerBase defines model for PeerBase. type PeerBase struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -611,7 +611,7 @@ type PeerBatch struct { AccessiblePeersCount int `json:"accessible_peers_count"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 56d06595f..47bcf2f32 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -4,15 +4,15 @@ import ( "encoding/json" "net/http" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" ) // GroupsHandler is a handler that returns groups of the account @@ -110,7 +110,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ ID: groupID, Name: req.Name, Peers: peers, @@ -154,10 +154,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ Name: req.Name, Peers: peers, - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } err = h.accountManager.SaveGroup(account.Id, user.Id, &group) @@ -240,7 +240,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { } } -func toGroupResponse(account *server.Account, group *server.Group) *api.Group { +func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 303efc9d7..3d74b848c 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -15,6 +15,7 @@ import ( "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -28,30 +29,30 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandler { +func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(accountID, userID string, group *server.Group) error { + SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_, groupID, _ string) (*server.Group, error) { + GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } if groupID == "id-jwt-group" { - return &server.Group{ + return &nbgroup.Group{ ID: "id-jwt-group", Name: "Default Group", - Issued: server.GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, }, nil } - return &server.Group{ + return &nbgroup.Group{ ID: "idofthegroup", Name: "Group", - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, }, nil }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { @@ -62,10 +63,10 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle Users: map[string]*server.User{ user.Id: user, }, - Groups: map[string]*server.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI}, + Groups: map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, }, }, user, nil }, @@ -118,7 +119,7 @@ func TestGetGroup(t *testing.T) { }, } - group := &server.Group{ + group := &nbgroup.Group{ ID: "idofthegroup", Name: "Group", } @@ -153,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &server.Group{} + got := &nbgroup.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index d035ae0b7..bdbeba346 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,7 +9,6 @@ import ( "github.com/rs/cors" "github.com/netbirdio/management-integrations/integrations" - s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index d4d2558e8..77b4578f8 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -6,8 +6,10 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -61,10 +63,18 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w groupsInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { @@ -75,11 +85,18 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe return } - update := &nbpeer.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name, - LoginExpirationEnabled: req.LoginExpirationEnabled} + update := &nbpeer.Peer{ + ID: peerID, + SSHEnabled: req.SshEnabled, + Name: req.Name, + LoginExpirationEnabled: req.LoginExpirationEnabled, + } if req.ApprovalRequired != nil { - update.Status = &nbpeer.PeerStatus{RequiresApproval: *req.ApprovalRequired} + // todo: looks like that we reset all status property, is it right? + update.Status = &nbpeer.PeerStatus{ + RequiresApproval: *req.ApprovalRequired, + } } peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) @@ -91,15 +108,24 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + + util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(accountID, peerID, userID) if err != nil { + log.Errorf("failed to delete peer: %v", err) util.WriteError(err, w) return } @@ -138,46 +164,68 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - - peers, err := h.accountManager.GetPeers(account.Id, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - - dnsDomain := h.accountManager.GetDNSDomain() - - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(err, w) - return - } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - - accessiblePeerNumbers := h.accessiblePeersNumber(account, peer.ID) - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) - } - util.WriteJSONObject(w, respBody) - return - default: + if r.Method != http.MethodGet { util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + return } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + peers, err := h.accountManager.GetPeers(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return + } + + dnsDomain := h.accountManager.GetDNSDomain() + + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + + accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID) + + respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) + } + + validPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + h.setApprovalRequiredFlag(respBody, validPeersMap) + + util.WriteJSONObject(w, respBody) } -func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) int { - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) - return len(netMap.Peers) + len(netMap.OfflinePeers) +func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) { + validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + return 0, err + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) + return len(netMap.Peers) + len(netMap.OfflinePeers), nil +} + +func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { + for _, peer := range respBody { + _, ok := approvedPeersMap[peer.Id] + if !ok { + peer.ApprovalRequired = true + } + } } func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { @@ -206,7 +254,7 @@ func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.Access return accessiblePeers } -func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMinimum { +func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { var groupsInfo []api.GroupMinimum groupsChecked := make(map[string]struct{}) for _, group := range groups { @@ -230,7 +278,7 @@ func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMin return groupsInfo } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core @@ -257,7 +305,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeers: accessiblePeer, - ApprovalRequired: &peer.Status.RequiresApproval, + ApprovalRequired: !approved, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } @@ -290,7 +338,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeersCount: accessiblePeersCount, - ApprovalRequired: &peer.Status.RequiresApproval, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index e6b858036..74e682854 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" @@ -51,7 +52,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Policies: []*server.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 7b68479ed..ebbd5954f 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -13,13 +13,12 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/status" ) const ( @@ -44,7 +43,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup SetupKeys: map[string]*server.SetupKey{ defaultKey.Key: defaultKey, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "id-all": {ID: "id-all", Name: "All"}, }, diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 4e2c3d0b3..2bb279c76 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -99,6 +99,8 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusUnprocessableEntity case status.Unauthorized: httpStatus = http.StatusUnauthorized + case status.BadRequest: + httpStatus = http.StatusBadRequest default: } msg = strings.ToLower(err.Error()) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go new file mode 100644 index 000000000..cd770a801 --- /dev/null +++ b/management/server/integrated_validator.go @@ -0,0 +1,80 @@ +package server + +import ( + "errors" + + "github.com/google/martian/v3/log" + + "github.com/netbirdio/netbird/management/server/account" +) + +// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. +// It retrieves the account associated with the provided userID, then updates the integrated validator groups +// with the provided list of group ids. The updated account is then saved. +// +// Parameters: +// - accountID: The ID of the account for which integrated validator groups are to be updated. +// - userID: The ID of the user whose account is being updated. +// - groups: A slice of strings representing the ids of integrated validator groups to be updated. +// +// Returns: +// - error: An error if any occurred during the process, otherwise returns nil +func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + ok, err := am.GroupValidation(accountID, groups) + if err != nil { + log.Debugf("error validating groups: %s", err.Error()) + return err + } + + if !ok { + log.Debugf("invalid groups") + return errors.New("invalid groups") + } + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + a, err := am.Store.GetAccountByUser(userID) + if err != nil { + return err + } + + var extra *account.ExtraSettings + + if a.Settings.Extra != nil { + extra = a.Settings.Extra + } else { + extra = &account.ExtraSettings{} + a.Settings.Extra = extra + } + extra.IntegratedValidatorGroups = groups + return am.Store.SaveAccount(a) +} + +func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if len(groups) == 0 { + return true, nil + } + accountsGroups, err := am.ListGroups(accountId) + if err != nil { + return false, err + } + for _, group := range groups { + var found bool + for _, accountGroup := range accountsGroups { + if accountGroup.ID == group { + found = true + break + } + } + if !found { + return false, nil + } + } + + return true, nil +} + +func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { + return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +} diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go new file mode 100644 index 000000000..e87755b87 --- /dev/null +++ b/management/server/integrated_validator/interface.go @@ -0,0 +1,19 @@ +package integrated_validator + +import ( + "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +// IntegratedValidator interface exists to avoid the circle dependencies +type IntegratedValidator interface { + ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) + PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer + IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) + GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + PeerDeleted(accountID, peerID string) error + SetPeerInvalidationListener(fn func(accountID string)) + Stop() +} diff --git a/management/server/integration_reference/integration_reference.go b/management/server/integration_reference/integration_reference.go new file mode 100644 index 000000000..254b4e62f --- /dev/null +++ b/management/server/integration_reference/integration_reference.go @@ -0,0 +1,23 @@ +package integration_reference + +import ( + "fmt" + "strings" +) + +// IntegrationReference holds the reference to a particular integration +type IntegrationReference struct { + ID int + IntegrationType string +} + +func (ir IntegrationReference) String() string { + return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) +} + +func (ir IntegrationReference) CacheKey(path ...string) string { + if len(path) == 0 { + return ir.String() + } + return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 6ea902003..98ad0de0c 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -19,6 +17,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/util" ) @@ -413,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index fb3f74cb9..13db5ae95 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -10,24 +10,22 @@ import ( sync2 "sync" "time" - "github.com/netbirdio/netbird/management/server/activity" - - "google.golang.org/grpc/credentials/insecure" - - "github.com/netbirdio/netbird/management/server" - pb "github.com/golang/protobuf/proto" //nolint - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/encryption" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -448,6 +446,43 @@ var _ = Describe("Management service", func() { }) }) +type MocIntegratedValidator struct { +} + +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} + +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for p := range peers { + validatedPeers[p] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() {} + func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -504,7 +539,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 7be3c818d..c479867d2 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,6 +5,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" @@ -32,7 +33,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, @@ -117,7 +118,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 9463498cf..8e7c47a28 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -31,12 +32,12 @@ type MockAccountManager struct { GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID, userID string) (*server.Group, error) - GetAllGroupsFunc func(accountID, userID string) ([]*server.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*server.Group, error) - SaveGroupFunc func(accountID, userID string, group *server.Group) error + GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(accountID, userID string, group *group.Group) error DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*server.Group, error) + ListGroupsFunc func(accountID string) ([]*group.Group, error) GroupAddPeerFunc func(accountID, groupID, peerID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error DeleteRuleFunc func(accountID, ruleID, userID string) error @@ -91,10 +92,20 @@ type MockAccountManager struct { DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager + UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error + GroupValidationFunc func(accountId string, groups []string) (bool, error) +} + +func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { + approvedPeers := make(map[string]struct{}) + for id := range account.Peers { + approvedPeers[id] = struct{}{} + } + return approvedPeers, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*server.Group, error) { +func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(accountId, groupID, userID) } @@ -102,7 +113,7 @@ func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*serv } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*server.Group, error) { +func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) { if am.GetAllGroupsFunc != nil { return am.GetAllGroupsFunc(accountID, userID) } @@ -261,7 +272,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*server.Group, error) { +func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(accountID, groupName) } @@ -269,7 +280,7 @@ func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*serv } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.Group) error { +func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(accountID, userID, group) } @@ -285,7 +296,7 @@ func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) err } // ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, error) { +func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) { if am.ListGroupsFunc != nil { return am.ListGroupsFunc(accountID) } @@ -694,3 +705,19 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { } return nil } + +// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface +func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + if am.UpdateIntegratedValidatorGroupsFunc != nil { + return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) + } + return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented") +} + +// GroupValidation mocks GroupValidation of the AccountManager interface +func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if am.GroupValidationFunc != nil { + return am.GroupValidationFunc(accountId, groups) + } + return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e521805c8..fa7793602 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -261,7 +262,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*Group) error { +func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d04ac1a20..b10f9387a 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -8,6 +8,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -759,7 +760,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createNSStore(t *testing.T) (Store, error) { @@ -831,12 +832,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &Group{ + newGroup1 := &nbgroup.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &Group{ + newGroup2 := &nbgroup.Group{ ID: group2ID, Name: group2ID, } diff --git a/management/server/peer.go b/management/server/peer.go index 7de1b6542..fda8e49e9 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -7,16 +7,12 @@ import ( "time" "github.com/rs/xid" - - "github.com/netbirdio/management-integrations/additions" - - "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" ) // PeerSync used as a data object between the gRPC API and AccountManager on Sync request. @@ -37,6 +33,8 @@ type PeerLogin struct { UserID string // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. SetupKey string + // ConnectionIP is the real IP of the peer + ConnectionIP net.IP } // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if @@ -52,6 +50,10 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) @@ -71,7 +73,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(peer.ID) + aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -167,7 +169,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) } - update, err = additions.ValidatePeersUpdateRequest(update, peer, userID, accountID, am.eventStore, am.GetDNSDomain()) + update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, err } @@ -244,6 +246,12 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, // the 2nd loop performs the actual modification for _, peer := range peers { + + err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID) + if err != nil { + return err + } + account.DeletePeer(peer.ID) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{ @@ -304,7 +312,17 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro if peer == nil { return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) } - return account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + if err != nil { + return nil, err + } + return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil } // GetPeerNetwork returns the Network for a given peer @@ -433,10 +451,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P CreatedAt: registrationTime, LoginExpirationEnabled: addedByUser, Ephemeral: ephemeral, - } - - if account.Settings.Extra != nil { - newPeer = additions.PreparePeer(newPeer, account.Settings.Extra) + Location: peer.Location, } // add peer to 'All' group @@ -467,6 +482,8 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } } + newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) + if addedByUser { user, err := account.FindUser(userID) if err != nil { @@ -492,7 +509,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P am.updateAccountPeers(account) - networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain) + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap) return newPeer, networkMap, nil } @@ -529,23 +550,53 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network if peerLoginExpired(peer, account) { return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + if requiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + if isStatusChanged { + am.updateAccountPeers(account) + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey) - if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. - return am.AddPeer(login.SetupKey, login.UserID, &nbpeer.Peer{ + newPeer := &nbpeer.Peer{ Key: login.WireGuardPubKey, Meta: login.Meta, SSHKey: login.SSHKey, - }) + } + if am.geo != nil && login.ConnectionIP != nil { + location, err := am.geo.Lookup(login.ConnectionIP) + if err != nil { + log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) + } else { + newPeer.Location.ConnectionIP = login.ConnectionIP + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + + } + } + + return am.AddPeer(login.SetupKey, login.UserID, newPeer) } log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") @@ -595,6 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) } + isRequiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peer, updated := updatePeerMeta(peer, login.Meta, account) if updated { shouldStoreAccount = true @@ -612,10 +664,23 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } } - if updateRemotePeers { + if updateRemotePeers || isStatusChanged { am.updateAccountPeers(account) } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + if isRequiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { @@ -764,8 +829,13 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } + for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(p.ID) + aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -789,8 +859,13 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco func (am *DefaultAccountManager) updateAccountPeers(account *Account) { peers := account.GetPeers() + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed send out updates to peers, failed to validate peer: %v", err) + return + } for _, peer := range peers { - remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain) + remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 7f6d440bb..6063cc2a7 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -199,8 +200,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 Group - group2 Group + group1 nbgroup.Group + group2 nbgroup.Group policy Policy ) diff --git a/management/server/policy.go b/management/server/policy.go index 8265dabb5..e162d2b3b 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,11 +5,11 @@ import ( "strconv" "strings" - "github.com/netbirdio/management-integrations/additions" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -211,7 +211,8 @@ type FirewallRule struct { // getPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator() for _, policy := range a.Policies { if !policy.Enabled { @@ -223,10 +224,8 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []* continue } - sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks) - destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil) - sourcePeers = additions.ValidatePeers(sourcePeers) - destinationPeers = additions.ValidatePeers(destinationPeers) + sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -264,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in all, err := a.GetGroupAll() if err != nil { log.Errorf("failed to get group all: %v", err) - all = &Group{} + all = &nbgroup.Group{} } return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { @@ -491,7 +490,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { +func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { @@ -512,6 +511,10 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou continue } + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + if peer.ID == peerID { peerInGroups = true continue diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 681bab1da..1ea3bb379 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" ) @@ -56,7 +57,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -135,16 +136,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, } + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(p.ID) + peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -299,7 +305,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -374,8 +380,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, } + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } + t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -403,7 +414,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -433,7 +444,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -454,7 +465,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -569,7 +580,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -644,10 +655,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }) + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -657,7 +672,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*FirewallRule{ @@ -673,7 +688,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -683,7 +698,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -698,19 +713,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -725,14 +740,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources("peerA") + peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5a56eaa8b..9f8ea08c9 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" ) @@ -858,7 +859,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { groups, err := am.ListGroups(account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *Group + var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -967,7 +968,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &Group{ + newGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1014,7 +1015,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createRouterStore(t *testing.T) (Store, error) { @@ -1195,7 +1196,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*Group{ + newGroup := []*nbgroup.Group{ { ID: routeGroup1, Name: routeGroup1, diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index b714652f1..43edabbd6 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -82,7 +83,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -91,7 +92,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -178,7 +179,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -187,7 +188,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index f6a6f92a7..e6a9c8467 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -17,6 +17,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -64,7 +65,7 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, sql.SetMaxOpenConns(conns) // TODO: make it configurable err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, + &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, ) @@ -99,17 +100,17 @@ func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics t // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { - log.Debugf("acquiring global lock") + log.Tracef("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Debugf("released global lock in %v", time.Since(start)) + log.Tracef("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Debugf("took %v to acquire global lock", took) + log.Tracef("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -118,7 +119,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { } func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { - log.Debugf("acquiring lock for account %s", accountID) + log.Tracef("acquiring lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) @@ -127,7 +128,7 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + log.Tracef("released lock for account %s in %v", accountID, time.Since(start)) } return unlock @@ -434,7 +435,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { } account.UsersG = nil - account.Groups = make(map[string]*Group, len(account.GroupsG)) + account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } diff --git a/management/server/user.go b/management/server/user.go index 15517db41..b955c4058 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -10,6 +10,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -49,23 +50,6 @@ type UserStatus string // UserRole is the role of a User type UserRole string -// IntegrationReference holds the reference to a particular integration -type IntegrationReference struct { - ID int - IntegrationType string -} - -func (ir IntegrationReference) String() string { - return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) -} - -func (ir IntegrationReference) CacheKey(path ...string) string { - if len(path) == 0 { - return ir.String() - } - return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) -} - // User represents a user of the system type User struct { Id string `gorm:"primaryKey"` @@ -91,7 +75,7 @@ type User struct { // Issued of the user Issued string `gorm:"default:api"` - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } // IsBlocked returns true if the user is blocked, false otherwise diff --git a/management/server/user_test.go b/management/server/user_test.go index e34aa406d..c92f87e6c 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -276,7 +277,7 @@ func TestUser_Copy(t *testing.T) { LastLogin: time.Now().UTC(), CreatedAt: time.Now().UTC(), Issued: "test", - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 0, IntegrationType: "test", }, @@ -603,8 +604,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + integratedPeerValidator: MocIntegratedValidator{}, } testCases := []struct { @@ -793,7 +795,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { Id: "externalUser", Role: UserRoleUser, Issued: UserIssuedIntegration, - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", }, From bd7a65d7984ad802e105a6451ffd31560ee8312e Mon Sep 17 00:00:00 2001 From: Jeremy Wu Date: Thu, 28 Mar 2024 16:56:41 +0800 Subject: [PATCH 05/28] support to configure extra blacklist of iface in "up" command (#1734) Support to configure extra blacklist of iface in "up" command --- client/cmd/root.go | 2 + client/cmd/up.go | 14 +- client/internal/config.go | 18 +- client/internal/config_test.go | 22 +- client/proto/daemon.pb.go | 401 +++++++++++++++++---------------- client/proto/daemon.proto | 2 + client/server/server.go | 8 +- 7 files changed, 259 insertions(+), 208 deletions(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index c3ff0a3c8..9c4ad99de 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -34,6 +34,7 @@ const ( wireguardPortFlag = "wireguard-port" disableAutoConnectFlag = "disable-auto-connect" serverSSHAllowedFlag = "allow-server-ssh" + extraIFaceBlackListFlag = "extra-iface-blacklist" ) var ( @@ -63,6 +64,7 @@ var ( wireguardPort uint16 serviceName string autoConnectDisabled bool + extraIFaceBlackList []string rootCmd = &cobra.Command{ Use: "netbird", Short: "", diff --git a/client/cmd/up.go b/client/cmd/up.go index f44f29a47..c2c3c7c90 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -40,6 +40,7 @@ func init() { upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") + upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") } func upFunc(cmd *cobra.Command, args []string) error { @@ -83,11 +84,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { } ic := internal.ConfigInput{ - ManagementURL: managementURL, - AdminURL: adminURL, - ConfigPath: configPath, - NATExternalIPs: natExternalIPs, - CustomDNSAddress: customDNSAddressConverted, + ManagementURL: managementURL, + AdminURL: adminURL, + ConfigPath: configPath, + NATExternalIPs: natExternalIPs, + CustomDNSAddress: customDNSAddressConverted, + ExtraIFaceBlackList: extraIFaceBlackList, } if cmd.Flag(enableRosenpassFlag).Changed { @@ -149,7 +151,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { - customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed) if err != nil { return err @@ -190,6 +191,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { CustomDNSAddress: customDNSAddressConverted, IsLinuxDesktopClient: isLinuxRunningDesktop(), Hostname: hostName, + ExtraIFaceBlacklist: extraIFaceBlackList, } if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { diff --git a/client/internal/config.go b/client/internal/config.go index 2f6958235..5b3c61cbd 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -30,8 +30,10 @@ const ( DefaultAdminURL = "https://app.netbird.io:443" ) -var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", - "Tailscale", "tailscale", "docker", "veth", "br-", "lo"} +var defaultInterfaceBlacklist = []string{ + iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", + "Tailscale", "tailscale", "docker", "veth", "br-", "lo", +} // ConfigInput carries configuration changes to the client type ConfigInput struct { @@ -47,6 +49,7 @@ type ConfigInput struct { InterfaceName *string WireguardPort *int DisableAutoConnect *bool + ExtraIFaceBlackList []string } // Config Configuration type @@ -220,7 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) { config.AdminURL = newURL } - config.IFaceBlackList = defaultInterfaceBlacklist + // nolint:gocritic + config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...) return config, nil } @@ -320,6 +324,13 @@ func update(input ConfigInput) (*Config, error) { refresh = true } + if len(input.ExtraIFaceBlackList) > 0 { + for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) { + config.IFaceBlackList = append(config.IFaceBlackList, iFace) + refresh = true + } + } + if refresh { // since we have new management URL, we need to update config file if err := util.WriteJson(input.ConfigPath, config); err != nil { @@ -384,7 +395,6 @@ func configFileIsExists(path string) bool { // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // The check is performed only for the NetBird's managed version. func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) { - defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL) if err != nil { return nil, err diff --git a/client/internal/config_test.go b/client/internal/config_test.go index 7453c8fdf..978d0b3df 100644 --- a/client/internal/config_test.go +++ b/client/internal/config_test.go @@ -18,7 +18,6 @@ func TestGetConfig(t *testing.T) { config, err := UpdateOrCreateConfig(ConfigInput{ ConfigPath: filepath.Join(t.TempDir(), "config.json"), }) - if err != nil { return } @@ -86,6 +85,26 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL) } +func TestExtraIFaceBlackList(t *testing.T) { + extraIFaceBlackList := []string{"eth1"} + path := filepath.Join(t.TempDir(), "config.json") + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: path, + ExtraIFaceBlackList: extraIFaceBlackList, + }) + if err != nil { + return + } + + assert.Contains(t, config.IFaceBlackList, "eth1") + readConf, err := util.ReadJson(path, config) + if err != nil { + return + } + + assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1") +} + func TestHiddenPreSharedKey(t *testing.T) { hidden := "**********" samplePreSharedKey := "mysecretpresharedkey" @@ -111,7 +130,6 @@ func TestHiddenPreSharedKey(t *testing.T) { ConfigPath: cfgFile, PreSharedKey: tt.preSharedKey, }) - if err != nil { t.Fatalf("failed to get cfg: %s", err) } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 81998b115..4b8502268 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -44,17 +44,18 @@ type LoginRequest struct { // cleanNATExternalIPs clean map list of external IPs. // This is needed because the generated code // omits initialized empty slices due to omitempty tags - CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` - CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` - IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` - Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` - RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` - InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` - WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` - OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` - DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` - ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` - RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` + CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` + IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` + Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` + RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` + InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` + WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` + OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` + DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` + RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` } func (x *LoginRequest) Reset() { @@ -202,6 +203,13 @@ func (x *LoginRequest) GetRosenpassPermissive() bool { return false } +func (x *LoginRequest) GetExtraIFaceBlacklist() []string { + if x != nil { + return x.ExtraIFaceBlacklist + } + return nil +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1385,7 +1393,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xdd, 0x06, 0x0a, 0x0c, 0x4c, 0x6f, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x8f, 0x07, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -1430,192 +1438,195 @@ var file_daemon_proto_rawDesc = []byte{ 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x10, 0x20, 0x01, 0x28, 0x08, 0x48, 0x06, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, - 0x88, 0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, - 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, - 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, - 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, - 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, - 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6e, - 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x28, 0x0a, - 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, - 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, - 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, - 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, - 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, - 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, - 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, - 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x32, 0x0a, - 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, 0x6e, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, 0x11, 0x47, - 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, - 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, 0x65, - 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, - 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, - 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, - 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, - 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, - 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, - 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, + 0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, + 0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63, + 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, + 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, + 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, + 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, + 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, + 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, + 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, + 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, + 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, + 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, + 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, + 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, + 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, + 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, + 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, + 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, + 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, + 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, + 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, + 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, + 0x52, 0x4c, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, + 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, + 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, + 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, + 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, - 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, - 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, - 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, + 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, - 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0b, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, - 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, - 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, 0x61, - 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, - 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x18, - 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, 0x18, - 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, - 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, - 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, - 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, - 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, - 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, - 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, + 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, + 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, + 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, + 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, + 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, - 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, - 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, - 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, - 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, - 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, - 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, - 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, - 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, - 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, - 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, - 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, - 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, - 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, - 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, - 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, + 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, + 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, + 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, + 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, + 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, + 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, + 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, + 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, + 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, + 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, + 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, + 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, + 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, + 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, + 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, + 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 8f9148d68..5f8878a11 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -70,6 +70,8 @@ message LoginRequest { optional bool serverSSHAllowed = 15; optional bool rosenpassPermissive = 16; + + repeated string extraIFaceBlacklist = 17; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index 481ef0f7c..d1d9dbda4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -152,7 +152,8 @@ func (s *Server) Start() error { // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, - mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe) { + mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, +) { backOff := getConnectWithBackoff(ctx) retryStarted := false @@ -351,6 +352,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.latestConfigInput.WireguardPort = &port } + if len(msg.ExtraIFaceBlacklist) > 0 { + inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + } + s.mutex.Unlock() if msg.OptionalPreSharedKey != nil { From 22beac1b1b510ac117d84f95af4ddb08d7a048b9 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 28 Mar 2024 12:33:56 +0100 Subject: [PATCH 06/28] Fix invalid token due to the cache race (#1763) --- management/server/grpcserver.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 340adcfc6..4df24711e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -343,10 +343,18 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p userID := "" // JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered, // or it uses a setup key to register. + if loginReq.GetJwtToken() != "" { - userID, err = s.validateToken(loginReq.GetJwtToken()) + for i := 0; i < 3; i++ { + userID, err = s.validateToken(loginReq.GetJwtToken()) + if err == nil { + break + } + log.Warnf("failed validating JWT token sent from peer %s with error %v. "+ + "Trying again as it may be due to the IdP cache issue", peerKey, err) + time.Sleep(200 * time.Millisecond) + } if err != nil { - log.Warnf("failed validating JWT token sent from peer %s", peerKey) return nil, err } } From 4fff93a1f228fe15a802b5a0862aa1180f912c5d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 28 Mar 2024 13:06:54 +0100 Subject: [PATCH 07/28] Ignore unsupported address families (#1766) --- client/internal/routemanager/systemops_linux.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 192509992..3510f9553 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -162,7 +162,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink add route: %w", err) } @@ -185,7 +185,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { Dst: ipNet, } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink add unreachable route: %w", err) } @@ -205,7 +205,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { Dst: ipNet, } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink remove unreachable route: %w", err) } @@ -231,7 +231,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink remove route: %w", err) } @@ -255,7 +255,7 @@ func flushRoutes(tableID, family int) error { routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} } } - if err := netlink.RouteDel(&routes[i]); err != nil { + if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) } } @@ -385,7 +385,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("add routing rule: %w", err) } @@ -402,7 +402,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("remove routing rule: %w", err) } From fd23d0c28ff069a5e60ed9c10139f4babf15a1b4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 28 Mar 2024 18:12:25 +0100 Subject: [PATCH 08/28] Don't block on failed routing setup (#1768) --- client/internal/engine.go | 3 +-- client/internal/routemanager/systemops_linux.go | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 7f7b5ef55..046a6c944 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -261,8 +261,7 @@ func (e *Engine) Start() error { e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) if err := e.routeManager.Init(); err != nil { - e.close() - return fmt.Errorf("init route manager: %w", err) + log.Errorf("Failed to initialize route manager: %s", err) } e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 3510f9553..83af5008a 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -45,10 +45,10 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, } } From 40d56e5d29608d55aa628a5172fceac1116a3aa9 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 28 Mar 2024 18:43:32 +0100 Subject: [PATCH 09/28] Update network security image (#1765) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9b52f5b5f..d0b07feff 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,8 @@ ### Open-Source Network Security in a Single Platform -![download (2)](https://github.com/netbirdio/netbird/assets/700848/16210ac2-7265-44c1-8d4e-8fae85534dac) +![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f) + ### Key features From 9c2dc05df1a735adc4ab8ec19fc6371c528817c2 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sun, 31 Mar 2024 19:39:52 +0200 Subject: [PATCH 10/28] Eval/higher timeouts (#1776) --- client/internal/peer/conn.go | 28 ++++++++++++++++------------ client/internal/peer/env_config.go | 27 ++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index c180e8f03..ce8cc4b97 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -26,6 +26,8 @@ import ( const ( iceKeepAliveDefault = 4 * time.Second iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second defaultWgKeepAlive = 25 * time.Second ) @@ -196,20 +198,22 @@ func (conn *Conn) reCreateAgent() error { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: conn.config.StunTurn, - CandidateTypes: conn.candidateTypes(), - FailedTimeout: &failedTimeout, - InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList), - UDPMux: conn.config.UDPMux, - UDPMuxSrflx: conn.config.UDPMuxSrflx, - NAT1To1IPs: conn.config.NATExternalIPs, - Net: transportNet, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: conn.config.StunTurn, + CandidateTypes: conn.candidateTypes(), + FailedTimeout: &failedTimeout, + InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList), + UDPMux: conn.config.UDPMux, + UDPMuxSrflx: conn.config.UDPMuxSrflx, + NAT1To1IPs: conn.config.NATExternalIPs, + Net: transportNet, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, } if conn.config.DisableIPv6Discovery { diff --git a/client/internal/peer/env_config.go b/client/internal/peer/env_config.go index 540bc413e..87b626df7 100644 --- a/client/internal/peer/env_config.go +++ b/client/internal/peer/env_config.go @@ -10,9 +10,10 @@ import ( ) const ( - envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" - envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" - envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" + envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" + envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" + envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" + envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" ) func iceKeepAlive() time.Duration { @@ -21,7 +22,7 @@ func iceKeepAlive() time.Duration { return iceKeepAliveDefault } - log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv) + log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) if err != nil { log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) @@ -37,7 +38,7 @@ func iceDisconnectedTimeout() time.Duration { return iceDisconnectedTimeoutDefault } - log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) + log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv) if err != nil { log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) @@ -47,6 +48,22 @@ func iceDisconnectedTimeout() time.Duration { return time.Duration(disconnectedTimeoutSec) * time.Second } +func iceRelayAcceptanceMinWait() time.Duration { + iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec) + if iceRelayAcceptanceMinWaitEnv == "" { + return iceRelayAcceptanceMinWaitDefault + } + + log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) + disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) + if err != nil { + log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) + return iceRelayAcceptanceMinWaitDefault + } + + return time.Duration(disconnectedTimeoutSec) * time.Second +} + func hasICEForceRelayConn() bool { disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) return strings.ToLower(disconnectedTimeoutEnv) == "true" From 23a14737974e3849fa86408d136cc46db8a885d0 Mon Sep 17 00:00:00 2001 From: Vilian Gerdzhikov Date: Tue, 2 Apr 2024 11:08:58 +0300 Subject: [PATCH 11/28] Fix grammar in readme (#1778) --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d0b07feff..d2a2bd6b9 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. -**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. +**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. ### Open-Source Network Security in a Single Platform @@ -77,7 +77,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird - **Public domain** name pointing to the VM. **Software requirements:** -- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher. +- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher. - [jq](https://jqlang.github.io/jq/) installed. In most distributions Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq` - [curl](https://curl.se/) installed. @@ -94,9 +94,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird - Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard. - Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers). - NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines. -- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. +- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. - Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates. -- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server. +- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server. [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups. @@ -120,7 +120,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu ![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png) ### Testimonials -We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution). +We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution). ### Legal _WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld. From 9af532fe719e3851db756087730f278ea5559751 Mon Sep 17 00:00:00 2001 From: rqi14 Date: Tue, 2 Apr 2024 19:43:57 +0800 Subject: [PATCH 12/28] Get scope from endpoint url instead of hardcoding (#1770) --- management/server/idp/azure.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 706e4d330..2f21b3b54 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -115,7 +115,15 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { data.Set("client_id", ac.clientConfig.ClientID) data.Set("client_secret", ac.clientConfig.ClientSecret) data.Set("grant_type", ac.clientConfig.GrantType) - data.Set("scope", "https://graph.microsoft.com/.default") + parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint) + if err != nil { + return nil, err + } + + // get base url and add "/.default" as scope + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + scopeURL := baseURL + "/.default" + data.Set("scope", scopeURL) payload := strings.NewReader(data.Encode()) req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload) From 79382951905d9fe399ed5b07fff1ddcda8de7ee2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 11:11:46 +0200 Subject: [PATCH 13/28] Feature/exit nodes - Windows and macOS support (#1726) --- .github/workflows/golang-test-darwin.yml | 3 + .github/workflows/golang-test-windows.yml | 2 +- client/internal/engine.go | 26 +- client/internal/peer/conn.go | 33 ++ client/internal/routemanager/client.go | 4 +- client/internal/routemanager/manager.go | 41 +- client/internal/routemanager/manager_test.go | 30 +- client/internal/routemanager/mock.go | 5 +- client/internal/routemanager/routemanager.go | 119 +++++ .../routemanager/server_nonandroid.go | 8 +- .../routemanager/systemops_android.go | 24 +- .../routemanager/systemops_bsd_nonios.go | 13 - .../internal/routemanager/systemops_darwin.go | 61 +++ .../routemanager/systemops_darwin_test.go | 100 +++++ client/internal/routemanager/systemops_ios.go | 24 +- .../internal/routemanager/systemops_linux.go | 44 +- .../routemanager/systemops_linux_test.go | 386 +++-------------- .../routemanager/systemops_nonandroid.go | 148 ------- .../routemanager/systemops_nonandroid_test.go | 282 ------------ .../routemanager/systemops_nonlinux.go | 408 +++++++++++++++++- .../routemanager/systemops_nonlinux_test.go | 242 ++++++++++- .../routemanager/systemops_unix_test.go | 234 ++++++++++ .../routemanager/systemops_windows.go | 81 +++- .../routemanager/systemops_windows_test.go | 289 +++++++++++++ client/internal/routemanager/sytemops_test.go | 101 +++++ client/internal/wgproxy/portlookup.go | 6 +- client/internal/wgproxy/proxy_ebpf.go | 6 +- go.mod | 2 +- go.sum | 6 +- util/grpc/{dialer_linux.go => dialer.go} | 10 +- util/grpc/dialer_generic.go | 9 - util/net/dialer.go | 64 +++ util/net/dialer_generic.go | 118 ++++- util/net/dialer_linux.go | 58 +-- util/net/dialer_nonlinux.go | 6 + util/net/listener.go | 21 + util/net/listener_generic.go | 153 ++++++- util/net/listener_linux.go | 24 +- util/net/listener_mobile.go | 11 + util/net/listener_nonlinux.go | 6 + util/net/net.go | 11 + 41 files changed, 2254 insertions(+), 965 deletions(-) create mode 100644 client/internal/routemanager/routemanager.go delete mode 100644 client/internal/routemanager/systemops_bsd_nonios.go create mode 100644 client/internal/routemanager/systemops_darwin.go create mode 100644 client/internal/routemanager/systemops_darwin_test.go delete mode 100644 client/internal/routemanager/systemops_nonandroid.go delete mode 100644 client/internal/routemanager/systemops_nonandroid_test.go create mode 100644 client/internal/routemanager/systemops_unix_test.go create mode 100644 client/internal/routemanager/systemops_windows_test.go create mode 100644 client/internal/routemanager/sytemops_test.go rename util/grpc/{dialer_linux.go => dialer.go} (56%) delete mode 100644 util/grpc/dialer_generic.go create mode 100644 util/net/dialer.go create mode 100644 util/net/dialer_nonlinux.go create mode 100644 util/net/listener.go create mode 100644 util/net/listener_mobile.go create mode 100644 util/net/listener_nonlinux.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index f8afd3d6e..d7007c860 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,6 +32,9 @@ jobs: restore-keys: | macos-go- + - name: Install libpcap + run: brew install libpcap + - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 6027d3626..2d63acbcd 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/client/internal/engine.go b/client/internal/engine.go index 046a6c944..d6238c4b3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -93,6 +93,10 @@ type Engine struct { mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn + + beforePeerHook peer.BeforeAddPeerHookFunc + afterPeerHook peer.AfterRemovePeerHookFunc + // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -260,9 +264,14 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) - if err := e.routeManager.Init(); err != nil { + beforePeerHook, afterPeerHook, err := e.routeManager.Init() + if err != nil { log.Errorf("Failed to initialize route manager: %s", err) + } else { + e.beforePeerHook = beforePeerHook + e.afterPeerHook = afterPeerHook } + e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -808,10 +817,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { if _, ok := e.peerConns[peerKey]; !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { - return err + return fmt.Errorf("create peer connection: %w", err) } e.peerConns[peerKey] = conn + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) @@ -1105,6 +1119,10 @@ func (e *Engine) close() { e.dnsServer.Stop() } + if e.routeManager != nil { + e.routeManager.Stop() + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { @@ -1119,10 +1137,6 @@ func (e *Engine) close() { } } - if e.routeManager != nil { - e.routeManager.Stop() - } - if e.firewall != nil { err := e.firewall.Reset() if err != nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ce8cc4b97..f3d07dcad 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -100,6 +101,9 @@ type IceCredentials struct { Pwd string } +type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error +type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error + type Conn struct { config ConnConfig mu sync.Mutex @@ -138,6 +142,10 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn + + connID nbnet.ConnectionID + beforeAddPeerHooks []BeforeAddPeerHookFunc + afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -393,6 +401,14 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } +func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { + conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) +} + +func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { + conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) +} + // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -419,6 +435,14 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.remoteEndpoint = endpointUdpAddr + log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + + conn.connID = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + log.Errorf("Before add peer hook failed: %v", err) + } + } err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { @@ -510,6 +534,15 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if conn.connID != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connID); err != nil { + log.Errorf("After remove peer hook failed: %v", err) + } + } + } + conn.connID = "" + if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index b2dff7f08..38cf4bf65 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -193,7 +193,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } @@ -234,7 +234,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // otherwise add the route to the system - if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6a0d954da..36a37f02c 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -3,7 +3,9 @@ package routemanager import ( "context" "fmt" + "net" "net/netip" + "net/url" "runtime" "sync" @@ -24,7 +26,7 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) // Manager is a route manager interface type Manager interface { - Init() error + Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -65,16 +67,21 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, } // Init sets up the routing -func (m *DefaultManager) Init() error { +func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { if err := cleanupRouting(); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } - if err := setupRouting(); err != nil { - return fmt.Errorf("setup routing: %w", err) + mgmtAddress := m.statusRecorder.GetManagementState().URL + signalAddress := m.statusRecorder.GetSignalState().URL + ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) + + beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) + if err != nil { + return nil, nil, fmt.Errorf("setup routing: %w", err) } log.Info("Routing setup complete") - return nil + return beforePeerHook, afterPeerHook, nil } func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { @@ -203,16 +210,36 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } func isPrefixSupported(prefix netip.Prefix) bool { - if runtime.GOOS == "linux" { + switch runtime.GOOS { + case "linux", "windows", "darwin": return true } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported // we skip this prefix management - if prefix.Bits() < minRangeBits { + if prefix.Bits() <= minRangeBits { log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", version.NetbirdVersion(), prefix) return false } return true } + +// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. +func resolveURLsToIPs(urls []string) []net.IP { + var ips []net.IP + for _, rawurl := range urls { + u, err := url.Parse(rawurl) + if err != nil { + log.Errorf("Failed to parse url %s: %v", rawurl, err) + continue + } + ipAddrs, err := net.LookupIP(u.Hostname()) + if err != nil { + log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) + continue + } + ips = append(ips, ipAddrs...) + } + return ips +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 9d92bf90d..03e77e09b 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,14 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int - clientNetworkWatchersExpectedLinux int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedAllowed int }{ { name: "Should create 2 client networks", @@ -201,9 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, - clientNetworkWatchersExpectedLinux: 1, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedAllowed: 1, }, { name: "Remove 1 Client Route", @@ -417,7 +417,9 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) - err = routeManager.Init() + + _, _, err = routeManager.Init() + require.NoError(t, err, "should init route manager") defer routeManager.Stop() @@ -434,8 +436,8 @@ func TestManagerUpdateRoutes(t *testing.T) { require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected - if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 { - expectedWatchers = testCase.clientNetworkWatchersExpectedLinux + if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed } require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index e812b3a85..dd2c28e59 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,6 +6,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -16,8 +17,8 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() error { - return nil +func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil } // InitialRouteRange mock implementation of InitialRouteRange from Manager interface diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go new file mode 100644 index 000000000..fe8d7b4ef --- /dev/null +++ b/client/internal/routemanager/routemanager.go @@ -0,0 +1,119 @@ +//go:build !android + +package routemanager + +import ( + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type ref struct { + count int + nexthop netip.Addr + intf string +} + +type RouteManager struct { + // refCountMap keeps track of the reference ref for prefixes + refCountMap map[netip.Prefix]ref + // prefixMap keeps track of the prefixes associated with a connection ID for removal + prefixMap map[nbnet.ConnectionID][]netip.Prefix + addRoute AddRouteFunc + removeRoute RemoveRouteFunc + mutex sync.Mutex +} + +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error + +func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { + // TODO: read initial routing table into refCountMap + return &RouteManager{ + refCountMap: map[netip.Prefix]ref{}, + prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, + addRoute: addRoute, + removeRoute: removeRoute, + } +} + +func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + ref := rm.refCountMap[prefix] + log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) + + // Add route to the system, only if it's a new prefix + if ref.count == 0 { + log.Debugf("Adding route for prefix %s", prefix) + nexthop, intf, err := rm.addRoute(prefix) + if err != nil { + return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) + } + ref.nexthop = nexthop + ref.intf = intf + } + + ref.count++ + rm.refCountMap[prefix] = ref + rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) + + return nil +} + +func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + prefixes, ok := rm.prefixMap[connID] + if !ok { + log.Debugf("No prefixes found for connection ID %s", connID) + return nil + } + + var result *multierror.Error + for _, prefix := range prefixes { + ref := rm.refCountMap[prefix] + log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) + if ref.count == 1 { + log.Debugf("Removing route for prefix %s", prefix) + // TODO: don't fail if the route is not found + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + continue + } + delete(rm.refCountMap, prefix) + } else { + ref.count-- + rm.refCountMap[prefix] = ref + } + } + delete(rm.prefixMap, connID) + + return result.ErrorOrNil() +} + +// Flush removes all references and routes from the system +func (rm *RouteManager) Flush() error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + var result *multierror.Error + for prefix := range rm.refCountMap { + log.Debugf("Removing route for prefix %s", prefix) + ref := rm.refCountMap[prefix] + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + } + } + rm.refCountMap = map[netip.Prefix]ref{} + rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 00df735fb..af82dc913 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -155,11 +155,13 @@ func (m *defaultServerRouter) cleanUp() { log.Errorf("Failed to remove cleanup route: %v", err) } - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } + + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } + func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { parsed, err := netip.ParsePrefix(source) if err != nil { diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 291826780..34d2d270f 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { + return nil +} + +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd_nonios.go b/client/internal/routemanager/systemops_bsd_nonios.go deleted file mode 100644 index f60c7afc3..000000000 --- a/client/internal/routemanager/systemops_bsd_nonios.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios - -package routemanager - -import "net/netip" - -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - return genericAddToRouteTableIfNoExists(prefix, addr, intf) -} - -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) -} diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go new file mode 100644 index 000000000..f34964a83 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin.go @@ -0,0 +1,61 @@ +//go:build darwin && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + "os/exec" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" +) + +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("add", prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("delete", prefix, nexthop, intf) +} + +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { + inet := "-inet" + if prefix.Addr().Is6() { + inet = "-inet6" + // Special case for IPv6 split default route, pointing to the wg interface fails + // TODO: Remove once we have IPv6 support on the interface + if prefix.Bits() == 1 { + intf = "lo0" + } + } + + args := []string{"-n", action, inet, prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } else if intf != "" { + args = append(args, "-interface", intf) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go new file mode 100644 index 000000000..5c5aaa24f --- /dev/null +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -0,0 +1,100 @@ +//go:build !ios + +package routemanager + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var expectedVPNint = "utun100" +var expectedExternalInt = "lo0" +var expectedInternalInt = "lo0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via vpn", + destination: "10.10.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), + }, + }...) +} + +func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { + t.Helper() + + err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() + require.NoError(t, err, "Failed to create loopback alias") + + t.Cleanup(func() { + err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() + assert.NoError(t, err, "Failed to remove loopback alias") + }) + + return "lo0" +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { + t.Helper() + + var originalNexthop net.IP + if dstCIDR == "0.0.0.0/0" { + var err error + originalNexthop, err = fetchOriginalGateway() + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { + t.Logf("Failed to delete route: %v, output: %s", err, output) + } + } + + t.Cleanup(func() { + if originalNexthop != nil { + err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() + assert.NoError(t, err, "Failed to restore original route") + } + }) + + err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() + require.NoError(t, err, "Failed to add route") + + t.Cleanup(func() { + err := exec.Command("route", "delete", "-net", dstCIDR).Run() + assert.NoError(t, err, "Failed to remove route") + }) +} + +func fetchOriginalGateway() (net.IP, error) { + output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() + if err != nil { + return nil, err + } + + matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) + if len(matches) == 0 { + return nil, fmt.Errorf("gateway not found") + } + + return net.ParseIP(matches[1]), nil +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 291826780..34d2d270f 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { + return nil +} + +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 83af5008a..d21a3bfbf 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -15,6 +15,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -64,7 +66,7 @@ func getSetupRules() []ruleParams { // enabling VPN connectivity. // // The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting() (err error) { +func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } @@ -80,11 +82,11 @@ func setupRouting() (err error) { rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { - return fmt.Errorf("%s: %w", rule.description, err) + return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } - return nil + return nil, nil, nil } // cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. @@ -110,7 +112,7 @@ func cleanupRouting() error { return result.ErrorOrNil() } -func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error { +func addVPNRoute(prefix netip.Prefix, intf string) error { // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 // TODO remove this once we have ipv6 support @@ -125,7 +127,7 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error { +func removeVPNRoute(prefix netip.Prefix, intf string) error { // TODO remove this once we have ipv6 support if prefix == defaultv4 { if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { @@ -138,10 +140,6 @@ func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) return nil } -func getRoutesFromTable() ([]netip.Prefix, error) { - return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4) -} - // addRoute adds a route to a specific routing table identified by tableID. func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { route := &netlink.Route{ @@ -263,34 +261,6 @@ func flushRoutes(tableID, family int) error { return result.ErrorOrNil() } -// getRoutes fetches routes from a specific routing table identified by tableID. -func getRoutes(tableID, family int) ([]netip.Prefix, error) { - var prefixList []netip.Prefix - - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) - } - - for _, route := range routes { - if route.Dst != nil { - addr, ok := netip.AddrFromSlice(route.Dst.IP) - if !ok { - return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) - } - - ones, _ := route.Dst.Mask.Size() - - prefix := netip.PrefixFrom(addr, ones) - if prefix.IsValid() { - prefixList = append(prefixList, prefix) - } - } - } - - return prefixList, nil -} - func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 96e43d20f..50a02401a 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -6,34 +6,40 @@ import ( "errors" "fmt" "net" - "net/netip" "os" "strings" "syscall" "testing" - "time" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/gopacket/gopacket/pcap" - "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool +var expectedVPNint = "wgtest0" +var expectedLoopbackInt = "lo" +var expectedExternalInt = "dummyext0" +var expectedInternalInt = "dummyint0" + +var errRouteNotFound = fmt.Errorf("route not found") + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via physical interface", + destination: "10.10.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), + }, + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.1:53", + expectedInterface: expectedLoopbackInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + }...) } func TestEntryExists(t *testing.T) { @@ -92,157 +98,7 @@ func TestEntryExists(t *testing.T) { } } -func TestRoutingWithTables(t *testing.T) { - testCases := []struct { - name string - destination string - captureInterface string - dialer *net.Dialer - packetExpectation PacketExpectation - }{ - { - name: "To external host without fwmark via vpn", - destination: "192.0.2.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with fwmark via physical interface", - destination: "192.0.2.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with fwmark via physical interface", - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - { - name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - - { - name: "To unique vpn route with fwmark via physical interface", - destination: "172.16.0.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53), - }, - { - name: "To unique vpn route without fwmark via vpn", - destination: "172.16.0.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53), - }, - - { - name: "To more specific route without fwmark via vpn interface", - destination: "10.10.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53), - }, - - { - name: "To more specific route (local) without fwmark via physical interface", - destination: "127.0.10.1:53", - captureInterface: "lo", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - wgIface, _, _ := setupTestEnv(t) - - // default route exists in main table and vpn table - err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.0.0.0/8 route exists in main table and vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.10.0.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 127.0.10.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // unique route in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.captureInterface, filter) - - sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.packetExpectation) - }) - } -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } - -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy { +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { t.Helper() dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} @@ -264,35 +120,52 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str require.NoError(t, err) } - return dummy + t.Cleanup(func() { + err := netlink.LinkDel(dummy) + assert.NoError(t, err) + }) + + return dummy.Name } -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { t.Helper() _, dstIPNet, err := net.ParseCIDR(dstCIDR) require.NoError(t, err) + // Handle existing routes with metric 0 + var originalNexthop net.IP + var originalLinkIndex int if dstIPNet.String() == "0.0.0.0/0" { - gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil { + var err error + originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) + if err != nil && !errors.Is(err, errRouteNotFound) { t.Logf("Failed to fetch original gateway: %v", err) } - // Handle existing routes with metric 0 - err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - if err == nil { - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - } else if !errors.Is(err, syscall.ESRCH) { - t.Logf("Failed to delete route: %v", err) + if originalNexthop != nil { + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + switch { + case err != nil && !errors.Is(err, syscall.ESRCH): + t.Logf("Failed to delete route: %v", err) + case err == nil: + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + default: + t.Logf("Failed to delete route: %v", err) + } } } + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + route := &netlink.Route{ Dst: dstIPNet, Gw: gw, @@ -307,9 +180,9 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } + require.NoError(t, err) } -// fetchOriginalGateway returns the original gateway IP address and the interface index. func fetchOriginalGateway(family int) (net.IP, int, error) { routes, err := netlink.RouteList(nil, family) if err != nil { @@ -317,153 +190,20 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } for _, route := range routes { - if route.Dst == nil { + if route.Dst == nil && route.Priority == 0 { return route.Gw, route.LinkIndex, nil } } - return nil, 0, fmt.Errorf("default route not found") + return nil, 0, errRouteNotFound } -func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) { +func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index) + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index) - - t.Cleanup(func() { - err := netlink.LinkDel(defaultDummy) - assert.NoError(t, err) - err = netlink.LinkDel(otherDummy) - assert.NoError(t, err) - }) - - return defaultDummy.Name, otherDummy.Name -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet(nil) - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) { - t.Helper() - - defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - err := setupRouting() - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - return wgIface, defaultDummy, otherDummy -} - -func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { - t.Helper() - - inactive, err := pcap.NewInactiveHandle(intf) - require.NoError(t, err, "Failed to create inactive pcap handle") - defer inactive.CleanUp() - - err = inactive.SetSnapLen(1600) - require.NoError(t, err, "Failed to set snap length on inactive handle") - - err = inactive.SetTimeout(time.Second * 10) - require.NoError(t, err, "Failed to set timeout on inactive handle") - - err = inactive.SetImmediateMode(true) - require.NoError(t, err, "Failed to set immediate mode on inactive handle") - - handle, err := inactive.Activate() - require.NoError(t, err, "Failed to activate pcap handle") - t.Cleanup(handle.Close) - - err = handle.SetBPFFilter(filter) - require.NoError(t, err, "Failed to set BPF filter") - - return handle -} - -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) { - t.Helper() - - if dialer == nil { - dialer = &net.Dialer{} - } - - if sourcePort != 0 { - localUDPAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: sourcePort, - } - dialer.LocalAddr = localUDPAddr - } - - msg := new(dns.Msg) - msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = []dns.Question{ - {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - conn, err := dialer.Dial("udp", destination) - require.NoError(t, err, "Failed to dial UDP") - defer conn.Close() - - data, err := msg.Pack() - require.NoError(t, err, "Failed to pack DNS message") - - _, err = conn.Write(data) - if err != nil { - if strings.Contains(err.Error(), "required key not available") { - t.Logf("Ignoring WireGuard key error: %v", err) - return - } - t.Fatalf("Failed to send DNS query: %v", err) - } -} - -func createBPFFilter(destination string) string { - host, port, err := net.SplitHostPort(destination) - if err != nil { - return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) - } - return "udp" -} - -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, - } + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) } diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go deleted file mode 100644 index 65f670ace..000000000 --- a/client/internal/routemanager/systemops_nonandroid.go +++ /dev/null @@ -1,148 +0,0 @@ -//go:build !android - -//nolint:unused -package routemanager - -import ( - "errors" - "fmt" - "net" - "net/netip" - "os/exec" - "runtime" - - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" -) - -var errRouteNotFound = fmt.Errorf("route not found") - -func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(defaultv4) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - addr := netip.MustParseAddr(defaultGateway.String()) - - if !prefix.Contains(addr) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(addr, 32) - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "") -} - -func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := genericAddRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return genericAddToRouteTable(prefix, addr, intf) -} - -func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTable(prefix, addr, intf) -} - -func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return fmt.Errorf("add route: %w", err) - } - log.Debugf(string(out)) - return nil -} - -func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) - } - cmd := exec.Command("route", args...) - out, err := cmd.Output() - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - log.Debugf(string(out)) - return nil -} - -func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { - r, err := netroute.New() - if err != nil { - return nil, fmt.Errorf("new netroute: %w", err) - } - _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) - if err != nil { - log.Errorf("Getting routes returned an error: %v", err) - return nil, errRouteNotFound - } - - if gateway == nil { - return preferredSrc, nil - } - - return gateway, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go deleted file mode 100644 index aae5e5faa..000000000 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ /dev/null @@ -1,282 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "bytes" - "fmt" - "net" - "net/netip" - "os" - "os/exec" - "runtime" - "strings" - "testing" - - "github.com/pion/transport/v3/stdnet" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/iface" -) - -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - - if runtime.GOOS == "linux" { - outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String()) - require.NoError(t, err, "getOutgoingInterfaceLinux should not return error") - if invert { - require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface") - } else { - require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface") - } - return - } - - prefixGateway, err := getExistingRIBRouteGateway(prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } -} - -func getOutgoingInterfaceLinux(destination string) (string, error) { - cmd := exec.Command("ip", "route", "get", destination) - output, err := cmd.Output() - if err != nil { - return "", fmt.Errorf("executing ip route get: %w", err) - } - - return parseOutgoingInterface(string(output)), nil -} - -func parseOutgoingInterface(routeGetOutput string) string { - fields := strings.Fields(routeGetOutput) - for i, field := range fields { - if field == "dev" && i+1 < len(fields) { - return fields[i+1] - } - } - return "" -} - -func TestAddRemoveRoutes(t *testing.T) { - testCases := []struct { - name string - prefix netip.Prefix - shouldRouteToWireguard bool - shouldBeRemoved bool - }{ - { - name: "Should Add And Remove Route 100.66.120.0/24", - prefix: netip.MustParsePrefix("100.66.120.0/24"), - shouldRouteToWireguard: true, - shouldBeRemoved: true, - }, - { - name: "Should Not Add Or Remove Route 127.0.0.1/32", - prefix: netip.MustParsePrefix("127.0.0.1/32"), - shouldRouteToWireguard: false, - shouldBeRemoved: false, - }, - } - - for n, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") - - require.NoError(t, setupRouting()) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - if testCase.shouldRouteToWireguard { - assertWGOutInterface(t, testCase.prefix, wgInterface, false) - } else { - assertWGOutInterface(t, testCase.prefix, wgInterface, true) - } - exists, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "existsInRouteTable should not return err") - if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - - internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - require.NoError(t, err) - - if testCase.shouldBeRemoved { - require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") - } else { - require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") - } - } - }) - } -} - -func TestGetExistingRIBRouteGateway(t *testing.T) { - gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil { - t.Fatal("shouldn't return error when fetching the gateway: ", err) - } - if gateway == nil { - t.Fatal("should return a gateway") - } - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var testingIP string - var testingPrefix netip.Prefix - for _, address := range addresses { - if address.Network() != "ip+net" { - continue - } - prefix := netip.MustParsePrefix(address.String()) - if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { - testingIP = prefix.Addr().String() - testingPrefix = prefix.Masked() - break - } - } - - localIP, err := getExistingRIBRouteGateway(testingPrefix) - if err != nil { - t.Fatal("shouldn't return error: ", err) - } - if localIP == nil { - t.Fatal("should return a gateway for local network") - } - if localIP.String() == gateway.String() { - t.Fatal("local ip should not match with gateway IP") - } - if localIP.String() != testingIP { - t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) - } -} - -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - t.Log("defaultGateway: ", defaultGateway) - if err != nil { - t.Fatal("shouldn't return error when fetching the gateway: ", err) - } - testCases := []struct { - name string - prefix netip.Prefix - preExistingPrefix netip.Prefix - shouldAddRoute bool - }{ - { - name: "Should Add And Remove random Route", - prefix: netip.MustParsePrefix("99.99.99.99/32"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if overlaps with default gateway", - prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), - shouldAddRoute: false, - }, - { - name: "Should Add Route if bigger network exists", - prefix: netip.MustParsePrefix("100.100.100.0/24"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: true, - }, - { - name: "Should Add Route if smaller network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if same network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: false, - }, - } - - for n, testCase := range testCases { - var buf bytes.Buffer - log.SetOutput(&buf) - defer func() { - log.SetOutput(os.Stderr) - }() - t.Run(testCase.name, func(t *testing.T) { - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") - - require.NoError(t, setupRouting()) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - MockAddr := wgInterface.Address().IP.String() - - // Prepare the environment - if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err when adding pre-existing route") - } - - // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err when adding route") - - if testCase.shouldAddRoute { - // test if route exists after adding - ok, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "should not return err") - require.True(t, ok, "route should exist") - - // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err") - } - - // route should either not have been added or should have been removed - // In case of already existing route, it should not have been added (but still exist) - ok, err := existsInRouteTable(testCase.prefix) - t.Log("Buffer string: ", buf.String()) - require.NoError(t, err, "should not return err") - - // Linux uses a separate routing table, so the route can exist in both tables. - // The main routing table takes precedence over the wireguard routing table. - if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { - require.False(t, ok, "route should not exist") - } - }) - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index d793f0fbd..4bc186f21 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,22 +1,416 @@ -//go:build !linux || android +//go:build !linux && !ios package routemanager import ( + "context" + "errors" + "fmt" + "net" + "net/netip" "runtime" + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) -func setupRouting() error { - return nil -} +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) -func cleanupRouting() error { - return nil -} +var errRouteNotFound = fmt.Errorf("route not found") func enableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Errorf("Getting routes returned an error: %v", err) + return netip.Addr{}, nil, errRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, errRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. +// If the next hop or interface is pointing to the VPN interface, it will return an error +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(): + return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) + case addr.IsLinkLocalUnicast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) + case addr.IsLinkLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) + case addr.IsInterfaceLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) + case addr.IsUnspecified(): + return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) + case addr.IsMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// addVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func addVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// removeVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func removeVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index afaf5ba77..adb83bac6 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -1,16 +1,250 @@ -//go:build !linux || android +//go:build !linux && !ios package routemanager import ( + "bytes" + "fmt" "net" "net/netip" + "os" + "strings" "testing" + "github.com/pion/transport/v3/stdnet" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/iface" ) +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} + +func TestAddRemoveRoutes(t *testing.T) { + testCases := []struct { + name string + prefix netip.Prefix + shouldRouteToWireguard bool + shouldBeRemoved bool + }{ + { + name: "Should Add And Remove Route 100.66.120.0/24", + prefix: netip.MustParsePrefix("100.66.120.0/24"), + shouldRouteToWireguard: true, + shouldBeRemoved: true, + }, + { + name: "Should Not Add Or Remove Route 127.0.0.1/32", + prefix: netip.MustParsePrefix("127.0.0.1/32"), + shouldRouteToWireguard: false, + shouldBeRemoved: false, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() + newNet, err := stdnet.NewNet() + if err != nil { + t.Fatal(err) + } + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + err = addVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + + if testCase.shouldRouteToWireguard { + assertWGOutInterface(t, testCase.prefix, wgInterface, false) + } else { + assertWGOutInterface(t, testCase.prefix, wgInterface, true) + } + exists, err := existsInRouteTable(testCase.prefix) + require.NoError(t, err, "existsInRouteTable should not return err") + if exists && testCase.shouldRouteToWireguard { + err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "removeVPNRoute should not return err") + + prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + + internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + require.NoError(t, err) + + if testCase.shouldBeRemoved { + require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") + } else { + require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") + } + } + }) + } +} + +func TestGetNextHop(t *testing.T) { + gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + if err != nil { + t.Fatal("shouldn't return error when fetching the gateway: ", err) + } + if !gateway.IsValid() { + t.Fatal("should return a gateway") + } + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var testingIP string + var testingPrefix netip.Prefix + for _, address := range addresses { + if address.Network() != "ip+net" { + continue + } + prefix := netip.MustParsePrefix(address.String()) + if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { + testingIP = prefix.Addr().String() + testingPrefix = prefix.Masked() + break + } + } + + localIP, _, err := getNextHop(testingPrefix.Addr()) + if err != nil { + t.Fatal("shouldn't return error: ", err) + } + if !localIP.IsValid() { + t.Fatal("should return a gateway for local network") + } + if localIP.String() == gateway.String() { + t.Fatal("local ip should not match with gateway IP") + } + if localIP.String() != testingIP { + t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) + } +} + +func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { + defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + t.Log("defaultGateway: ", defaultGateway) + if err != nil { + t.Fatal("shouldn't return error when fetching the gateway: ", err) + } + testCases := []struct { + name string + prefix netip.Prefix + preExistingPrefix netip.Prefix + shouldAddRoute bool + }{ + { + name: "Should Add And Remove random Route", + prefix: netip.MustParsePrefix("99.99.99.99/32"), + shouldAddRoute: true, + }, + { + name: "Should Not Add Route if overlaps with default gateway", + prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), + shouldAddRoute: false, + }, + { + name: "Should Add Route if bigger network exists", + prefix: netip.MustParsePrefix("100.100.100.0/24"), + preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), + shouldAddRoute: true, + }, + { + name: "Should Add Route if smaller network exists", + prefix: netip.MustParsePrefix("100.100.0.0/16"), + preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"), + shouldAddRoute: true, + }, + { + name: "Should Not Add Route if same network exists", + prefix: netip.MustParsePrefix("100.100.0.0/16"), + preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), + shouldAddRoute: false, + }, + } + + for n, testCase := range testCases { + var buf bytes.Buffer + log.SetOutput(&buf) + defer func() { + log.SetOutput(os.Stderr) + }() + t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() + newNet, err := stdnet.NewNet() + if err != nil { + t.Fatal(err) + } + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // Prepare the environment + if testCase.preExistingPrefix.IsValid() { + err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + require.NoError(t, err, "should not return err when adding pre-existing route") + } + + // Add the route + err = addVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "should not return err when adding route") + + if testCase.shouldAddRoute { + // test if route exists after adding + ok, err := existsInRouteTable(testCase.prefix) + require.NoError(t, err, "should not return err") + require.True(t, ok, "route should exist") + + // remove route again if added + err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "should not return err") + } + + // route should either not have been added or should have been removed + // In case of already existing route, it should not have been added (but still exist) + ok, err := existsInRouteTable(testCase.prefix) + t.Log("Buffer string: ", buf.String()) + require.NoError(t, err, "should not return err") + + if !strings.Contains(buf.String(), "because it already exists") { + require.False(t, ok, "route should not exist") + } + }) + } +} + func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { @@ -50,7 +284,8 @@ func TestIsSubRange(t *testing.T) { } func TestExistsInRouteTable(t *testing.T) { - require.NoError(t, setupRouting()) + _, _, err := setupRouting(nil, nil) + require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) }) @@ -63,7 +298,8 @@ func TestExistsInRouteTable(t *testing.T) { var addressPrefixes []netip.Prefix for _, address := range addresses { p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() { addressPrefixes = append(addressPrefixes, p.Masked()) } } diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go new file mode 100644 index 000000000..561eaeea4 --- /dev/null +++ b/client/internal/routemanager/systemops_unix_test.go @@ -0,0 +1,234 @@ +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly + +package routemanager + +import ( + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +type testCase struct { + name string + destination string + expectedInterface string + dialer dialer + expectedPacket PacketExpectation +} + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.expectedInterface, filter) + + sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.expectedPacket) + }) + } +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + switch dialer := dialer.(type) { + case *nbnet.Dialer: + dialer.LocalAddr = localUDPAddr + case *net.Dialer: + dialer.LocalAddr = localUDPAddr + default: + t.Fatal("Unsupported dialer type") + } + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index c009ce66b..50fff0cd5 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -6,9 +6,14 @@ import ( "fmt" "net" "net/netip" + "os/exec" + "strings" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -16,6 +21,16 @@ type Win32_IP4RouteTable struct { Mask string } +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" @@ -48,10 +63,68 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - return genericAddToRouteTableIfNoExists(prefix, addr, intf) +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + destinationPrefix := prefix.String() + psCmd := "New-NetRoute" + + addressFamily := "IPv4" + if prefix.Addr().Is6() { + addressFamily = "IPv6" + } + + script := fmt.Sprintf( + `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, intf, + ) + + if nexthop.IsValid() { + script = fmt.Sprintf( + `%s -NextHop "%s"`, script, nexthop, + ) + } + + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + log.Tracef("PowerShell add route: %s", string(out)) + + if err != nil { + return fmt.Errorf("PowerShell add route: %w", err) + } + + return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"add", prefix.String(), nexthop.Unmap().String()} + + out, err := exec.Command("route", args...).CombinedOutput() + + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + if err != nil { + return fmt.Errorf("route add: %w", err) + } + + return nil +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + // Powershell doesn't support adding routes without an interface but allows to add interface by name + if intf != "" { + return addRoutePowershell(prefix, nexthop, intf) + } + return addRouteCmd(prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"delete", prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil } diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go new file mode 100644 index 000000000..a5e03b8d2 --- /dev/null +++ b/client/internal/routemanager/systemops_windows_test.go @@ -0,0 +1,289 @@ +package routemanager + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +var expectedExtInt = "Ethernet1" + +type RouteInfo struct { + NextHop string `json:"nexthop"` + InterfaceAlias string `json:"interfacealias"` + RouteMetric int `json:"routemetric"` +} + +type FindNetRouteOutput struct { + IPAddress string `json:"IPAddress"` + InterfaceIndex int `json:"InterfaceIndex"` + InterfaceAlias string `json:"InterfaceAlias"` + AddressFamily int `json:"AddressFamily"` + NextHop string `json:"NextHop"` + DestinationPrefix string `json:"DestinationPrefix"` +} + +type testCase struct { + name string + destination string + expectedSourceIP string + expectedDestPrefix string + expectedNextHop string + expectedInterface string + dialer dialer +} + +var expectedVPNint = "wgtest0" + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "128.0.0.0/1", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedDestPrefix: "192.0.2.1/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedDestPrefix: "10.0.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "10.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedDestPrefix: "172.16.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "172.16.0.0/12", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route without custom dialer via vpn interface", + destination: "10.10.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "10.10.0.0/24", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "127.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + route, err := fetchOriginalGateway() + require.NoError(t, err, "Failed to fetch original gateway") + ip, err := fetchInterfaceIP(route.InterfaceAlias) + require.NoError(t, err, "Failed to fetch interface IP") + + output := testRoute(t, tc.destination, tc.dialer) + if tc.expectedInterface == expectedExtInt { + verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) + } else { + verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + } + }) + } +} + +// fetchInterfaceIP fetches the IPv4 address of the specified interface. +func fetchInterfaceIP(interfaceAlias string) (string, error) { + script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) + out, err := exec.Command("powershell", "-Command", script).Output() + if err != nil { + return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) + } + + ip := strings.TrimSpace(string(out)) + return ip, nil +} + +func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "udp", destination) + require.NoError(t, err, "Failed to dial destination") + defer func() { + err := conn.Close() + assert.NoError(t, err, "Failed to close connection") + }() + + host, _, err := net.SplitHostPort(destination) + require.NoError(t, err) + + script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) + + out, err := exec.Command("powershell", "-Command", script).Output() + require.NoError(t, err, "Failed to execute Find-NetRoute") + + var outputs []FindNetRouteOutput + err = json.Unmarshal(out, &outputs) + require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") + + require.Greater(t, len(outputs), 0, "No route found for destination") + combinedOutput := combineOutputs(outputs) + + return combinedOutput +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + t.Helper() + + ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) + require.NoError(t, err) + subnetMaskSize, _ := ipNet.Mask.Size() + script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to assign IP address to loopback adapter") + + // Wait for the IP address to be applied + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = waitForIPAddress(ctx, interfaceName, ip.String()) + require.NoError(t, err, "IP address not applied within timeout") + + t.Cleanup(func() { + script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to remove IP address from loopback adapter") + }) + + return interfaceName +} + +func fetchOriginalGateway() (*RouteInfo, error) { + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) + } + + var routeInfo RouteInfo + err = json.Unmarshal(output, &routeInfo) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON output: %w", err) + } + + return &routeInfo, nil +} + +func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { + t.Helper() + + assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") + assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") + assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") + assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") +} + +func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() + if err != nil { + return err + } + + ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") + for _, ip := range ipAddresses { + if strings.TrimSpace(ip) == expectedIPAddress { + return nil + } + } + } + } +} + +func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { + var combined FindNetRouteOutput + + for _, output := range outputs { + if output.IPAddress != "" { + combined.IPAddress = output.IPAddress + } + if output.InterfaceIndex != 0 { + combined.InterfaceIndex = output.InterfaceIndex + } + if output.InterfaceAlias != "" { + combined.InterfaceAlias = output.InterfaceAlias + } + if output.AddressFamily != 0 { + combined.AddressFamily = output.AddressFamily + } + if output.NextHop != "" { + combined.NextHop = output.NextHop + } + if output.DestinationPrefix != "" { + combined.DestinationPrefix = output.DestinationPrefix + } + } + + return &combined +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") +} diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go new file mode 100644 index 000000000..28a6502d2 --- /dev/null +++ b/client/internal/routemanager/sytemops_test.go @@ -0,0 +1,101 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/iface" +) + +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet(nil) + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/portlookup.go index 6ede4b83f..6f3d33487 100644 --- a/client/internal/wgproxy/portlookup.go +++ b/client/internal/wgproxy/portlookup.go @@ -1,10 +1,8 @@ package wgproxy import ( - "context" "fmt" - - nbnet "github.com/netbirdio/netbird/util/net" + "net" ) const ( @@ -25,7 +23,7 @@ func (pl portLookup) searchFreePort() (int, error) { } func (pl portLookup) tryToBind(port int) error { - l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port)) + l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) if err != nil { return err } diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index b91cd7b43..2235c5d2b 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/ebpf" @@ -29,7 +30,7 @@ type WGEBPFProxy struct { turnConnMutex sync.Mutex rawConn net.PacketConn - conn *net.UDPConn + conn transport.UDPConn } // NewWGEBPFProxy create new WGEBPFProxy instance @@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - p.conn, err = nbnet.ListenUDP("udp", &addr) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -75,6 +76,7 @@ func (p *WGEBPFProxy) Listen() error { } return err } + p.conn = conn go p.proxyToRemote() log.Infof("local wg proxy listening on: %d", wgPorxyPort) diff --git a/go.mod b/go.mod index 5566f8559..29a1570c8 100644 --- a/go.mod +++ b/go.mod @@ -53,7 +53,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 - github.com/libp2p/go-netroute v0.2.0 + github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.5 github.com/mattn/go-sqlite3 v1.14.19 github.com/mdlayher/socket v0.4.1 diff --git a/go.sum b/go.sum index 6da405341..b488a42a4 100644 --- a/go.sum +++ b/go.sum @@ -345,8 +345,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= -github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= +github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU= +github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= @@ -659,7 +659,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -746,7 +745,6 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/util/grpc/dialer_linux.go b/util/grpc/dialer.go similarity index 56% rename from util/grpc/dialer_linux.go rename to util/grpc/dialer.go index b29ee4b29..96b2bc32b 100644 --- a/util/grpc/dialer_linux.go +++ b/util/grpc/dialer.go @@ -1,11 +1,10 @@ -//go:build !android - package grpc import ( "context" "net" + log "github.com/sirupsen/logrus" "google.golang.org/grpc" nbnet "github.com/netbirdio/netbird/util/net" @@ -13,6 +12,11 @@ import ( func WithCustomDialer() grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return nbnet.NewDialer().DialContext(ctx, "tcp", addr) + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, err + } + return conn, nil }) } diff --git a/util/grpc/dialer_generic.go b/util/grpc/dialer_generic.go deleted file mode 100644 index 1c2285b14..000000000 --- a/util/grpc/dialer_generic.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !linux || android - -package grpc - -import "google.golang.org/grpc" - -func WithCustomDialer() grpc.DialOption { - return grpc.EmptyDialOption{} -} diff --git a/util/net/dialer.go b/util/net/dialer.go new file mode 100644 index 000000000..7b9bddbb5 --- /dev/null +++ b/util/net/dialer.go @@ -0,0 +1,64 @@ +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +// Dialer extends the standard net.Dialer with the ability to execute hooks before +// and after connections. This can be used to bypass the VPN for connections using this dialer. +type Dialer struct { + *net.Dialer +} + +// NewDialer returns a customized net.Dialer with overridden Control method +func NewDialer() *Dialer { + dialer := &Dialer{ + Dialer: &net.Dialer{}, + } + dialer.init() + + return dialer +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to closeConn connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type") + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type") + } + + return tcpConn, nil +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index a3c3ad67c..2e102da50 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -1,19 +1,123 @@ -//go:build !linux || android +//go:build !android && !ios package net import ( + "context" + "fmt" "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" ) -func NewDialer() *net.Dialer { - return &net.Dialer{} +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc +) + +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) } -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - return net.DialUDP(network, laddr, raddr) +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) } -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - return net.DialTCP(network, laddr, raddr) +// RemoveDialerHook removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil +} + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialerDialHooks != nil { + if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} + +func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() } diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go index d559490c5..aed5c59a3 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_linux.go @@ -2,59 +2,11 @@ package net -import ( - "context" - "fmt" - "net" - "syscall" +import "syscall" - log "github.com/sirupsen/logrus" -) - -func NewDialer() *net.Dialer { - return &net.Dialer{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, +// init configures the net.Dialer Control function to set the fwmark on the socket +func (d *Dialer) init() { + d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type") - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type") - } - - return tcpConn, nil -} diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go new file mode 100644 index 000000000..3254e6d06 --- /dev/null +++ b/util/net/dialer_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (d *Dialer) init() { +} diff --git a/util/net/listener.go b/util/net/listener.go new file mode 100644 index 000000000..f4d769f58 --- /dev/null +++ b/util/net/listener.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before +// responding via the socket and after closing. This can be used to bypass the VPN for listeners. +type ListenerConfig struct { + *net.ListenConfig +} + +// NewListener creates a new ListenerConfig instance. +func NewListener() *ListenerConfig { + listener := &ListenerConfig{ + ListenConfig: &net.ListenConfig{}, + } + listener.init() + + return listener +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index 241c744e5..ae412415f 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -1,13 +1,154 @@ -//go:build !linux || android +//go:build !android && !ios package net -import "net" +import ( + "context" + "fmt" + "net" + "sync" -func NewListener() *net.ListenConfig { - return &net.ListenConfig{} + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) } -func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) { - return net.ListenUDP(network, locAddr) +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +// RemoveListenerHooks removes all dialer hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil +} + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.UDPConn) +} + +func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + return + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + return + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(id, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } +} + +func closeConn(id ConnectionID, conn net.PacketConn) error { + err := conn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(id, conn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + return err +} + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { + udpConn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + connID := GenerateConnID() + return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil } diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go index 7b9bda97c..8d332160a 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_linux.go @@ -3,28 +3,12 @@ package net import ( - "context" - "fmt" - "net" "syscall" ) -func NewListener() *net.ListenConfig { - return &net.ListenConfig{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } } - -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err) - } - udpConn, ok := pc.(*net.UDPConn) - if !ok { - return nil, fmt.Errorf("packetConn is not a *net.UDPConn") - } - return udpConn, nil -} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go new file mode 100644 index 000000000..0dbbb360b --- /dev/null +++ b/util/net/listener_mobile.go @@ -0,0 +1,11 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, laddr) +} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go new file mode 100644 index 000000000..fb6eadaaa --- /dev/null +++ b/util/net/listener_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (l *ListenerConfig) init() { +} diff --git a/util/net/net.go b/util/net/net.go index 5714e5229..9ea7ae803 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,6 +1,17 @@ package net +import "github.com/google/uuid" + const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard NetbirdFwmark = 0x1BD00 ) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} From bb0d5c5bafebe63653018387c8cdfd712ec65ba0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 18:04:22 +0200 Subject: [PATCH 14/28] Linux legacy routing (#1774) * Add Linux legacy routing if ip rule functionality is not available * Ignore exclusion route errors if host has no route * Exclude iOS from route manager * Also retrieve IPv6 routes * Ignore loopback addresses not being in the main table * Ignore "not supported" errors on cleanup * Fix regression in ListenUDP not using fwmarks --- client/internal/routemanager/routemanager.go | 6 +- client/internal/routemanager/systemops.go | 410 ++++++++++++++++++ .../internal/routemanager/systemops_linux.go | 138 ++++-- .../routemanager/systemops_linux_test.go | 2 - .../routemanager/systemops_nonlinux.go | 397 +---------------- ...ops_nonlinux_test.go => systemops_test.go} | 154 +++++-- client/internal/routemanager/sytemops_test.go | 101 ----- util/net/dialer.go | 2 +- util/net/listener_generic.go | 15 +- 9 files changed, 656 insertions(+), 569 deletions(-) create mode 100644 client/internal/routemanager/systemops.go rename client/internal/routemanager/{systemops_nonlinux_test.go => systemops_test.go} (70%) delete mode 100644 client/internal/routemanager/sytemops_test.go diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index fe8d7b4ef..39b55f105 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -1,8 +1,9 @@ -//go:build !android +//go:build !android && !ios package routemanager import ( + "errors" "fmt" "net/netip" "sync" @@ -53,6 +54,9 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref if ref.count == 0 { log.Debugf("Adding route for prefix %s", prefix) nexthop, intf, err := rm.addRoute(prefix) + if errors.Is(err, errRouteNotFound) { + return nil + } if err != nil { return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 000000000..c6f3376e0 --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,410 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var errRouteNotFound = fmt.Errorf("route not found") + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Warnf("Failed to get route for %s: %v", ip, err) + return netip.Addr{}, nil, errRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, errRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. +// If the next hop or interface is pointing to the VPN interface, it will return an error +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(): + return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) + case addr.IsLinkLocalUnicast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) + case addr.IsLinkLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) + case addr.IsInterfaceLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) + case addr.IsUnspecified(): + return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) + case addr.IsMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func genericAddVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil && !errors.Is(err, errRouteNotFound) { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil && !errors.Is(err, errRouteNotFound) { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index d21a3bfbf..44691f0d6 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -35,6 +35,9 @@ const ( var ErrTableIDExists = errors.New("ID exists with different name") +var routeManager = &RouteManager{} +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" + type ruleParams struct { fwmark int tableID int @@ -66,7 +69,12 @@ func getSetupRules() []ruleParams { // enabling VPN connectivity. // // The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { + if isLegacy { + log.Infof("Using legacy routing setup") + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } @@ -82,6 +90,11 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { + if errors.Is(err, syscall.EOPNOTSUPP) { + log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") + isLegacy = true + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } @@ -93,6 +106,10 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. func cleanupRouting() error { + if isLegacy { + return cleanupRoutingWithRouteManager(routeManager) + } + var result *multierror.Error if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { @@ -104,7 +121,7 @@ func cleanupRouting() error { rules := getSetupRules() for _, rule := range rules { - if err := removeAllRules(rule); err != nil { + if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) } } @@ -112,49 +129,104 @@ func cleanupRouting() error { return result.ErrorOrNil() } +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + func addVPNRoute(prefix netip.Prefix, intf string) error { - // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 + if isLegacy { + return genericAddVPNRoute(prefix, intf) + } + + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("add blackhole: %w", err) } } - if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("add route: %w", err) } return nil } func removeVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericRemoveVPNRoute(prefix, intf) + } + // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove unreachable route: %w", err) } } - if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove route: %w", err) } return nil } +func getRoutesFromTable() ([]netip.Prefix, error) { + v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) + if err != nil { + return nil, fmt.Errorf("get v4 routes: %w", err) + } + v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) + if err != nil { + return nil, fmt.Errorf("get v6 routes: %w", err) + + } + return append(v4Routes, v6Routes...), nil +} + +// getRoutes fetches routes from a specific routing table identified by tableID. +func getRoutes(tableID, family int) ([]netip.Prefix, error) { + var prefixList []netip.Prefix + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) + } + + for _, route := range routes { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) + } + + ones, _ := route.Dst.Mask.Size() + + prefix := netip.PrefixFrom(addr, ones) + if prefix.IsValid() { + prefixList = append(prefixList, prefix) + } + } + } + + return prefixList, nil +} + // addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { +func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, - Family: family, + Family: getAddressFamily(prefix), } - if prefix != nil { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - route.Dst = ipNet + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) } + route.Dst = ipNet if err := addNextHop(addr, intf, route); err != nil { return fmt.Errorf("add gateway and device: %w", err) @@ -170,7 +242,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err // addUnreachableRoute adds an unreachable route for the specified IP family and routing table. // ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. // tableID specifies the routing table to which the unreachable route will be added. -func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { +func addUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -179,7 +251,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { route := &netlink.Route{ Type: syscall.RTN_UNREACHABLE, Table: tableID, - Family: ipFamily, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -190,7 +262,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { return nil } -func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { +func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -199,7 +271,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { route := &netlink.Route{ Type: syscall.RTN_UNREACHABLE, Table: tableID, - Family: ipFamily, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -212,7 +284,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { } // removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -221,7 +293,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, - Family: family, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -392,23 +464,25 @@ func removeAllRules(params ruleParams) error { } // addNextHop adds the gateway and device to the route. -func addNextHop(addr *string, intf *string, route *netlink.Route) error { - if addr != nil { - ip := net.ParseIP(*addr) - if ip == nil { - return fmt.Errorf("parsing address %s failed", *addr) - } - - route.Gw = ip +func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { + if addr.IsValid() { + route.Gw = addr.AsSlice() } - if intf != nil { - link, err := netlink.LinkByName(*intf) + if intf != "" { + link, err := netlink.LinkByName(intf) if err != nil { - return fmt.Errorf("set interface %s: %w", *intf, err) + return fmt.Errorf("set interface %s: %w", intf, err) } route.LinkIndex = link.Attrs().Index } return nil } + +func getAddressFamily(prefix netip.Prefix) int { + if prefix.Addr().Is4() { + return netlink.FAMILY_V4 + } + return netlink.FAMILY_V6 +} diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 50a02401a..d77c7cc7d 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -21,8 +21,6 @@ var expectedLoopbackInt = "lo" var expectedExternalInt = "dummyext0" var expectedInternalInt = "dummyint0" -var errRouteNotFound = fmt.Errorf("route not found") - func init() { testCases = append(testCases, []testCase{ { diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 4bc186f21..38026107e 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -3,414 +3,21 @@ package routemanager import ( - "context" - "errors" - "fmt" - "net" "net/netip" "runtime" - "github.com/hashicorp/go-multierror" - "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - -var errRouteNotFound = fmt.Errorf("route not found") - func enableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } -// TODO: fix: for default our wg address now appears as the default gw -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(defaultGateway) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) - if defaultGateway.Is6() { - gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - var exitIntf string - gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - if intf != nil { - exitIntf = intf.Name - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) -} - -func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Errorf("Getting routes returned an error: %v", err) - return netip.Addr{}, nil, errRouteNotFound - } - - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - if preferredSrc == nil { - return netip.Addr{}, nil, errRouteNotFound - } - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) - } - return addr.Unmap(), intf, nil - } - - addr, ok := netip.AddrFromSlice(gateway) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) - } - - return addr.Unmap(), intf, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. -// If the next hop or interface is pointing to the VPN interface, it will return an error -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(): - return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) - case addr.IsLinkLocalUnicast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) - case addr.IsLinkLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) - case addr.IsInterfaceLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) - case addr.IsUnspecified(): - return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) - case addr.IsMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) - } - - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) - if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err) - } - - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } - - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") - } - - // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) - exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } - } - - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) - } - - return exitNextHop, exitIntf, nil -} - -// addVPNRoute adds a new route to the vpn interface, it splits the default prefix -// in two /1 prefixes to avoid replacing the existing default route func addVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - return err - } - if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } - - // TODO: remove once IPv6 is supported on the interface - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } else if prefix == defaultv6 { - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } - - return addNonExistingRoute(prefix, intf) + return genericAddVPNRoute(prefix, intf) } -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, netip.Addr{}, intf) -} - -// removeVPNRoute removes the route from the vpn interface. If a default prefix is given, -// it will remove the split /1 prefixes func removeVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - // TODO: remove once IPv6 is supported on the interface - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } else if prefix == defaultv6 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } - - return removeFromRouteTable(prefix, netip.Addr{}, intf) -} - -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil { - log.Errorf("Unable to get initial v4 default next hop: %v", err) - } - initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil { - log.Errorf("Unable to get initial v6 default next hop: %v", err) - } - - *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { - addr := prefix.Addr() - nexthop, intf := initialNextHopV4, initialIntfV4 - if addr.Is6() { - nexthop, intf = initialNextHopV6, initialIntfV6 - } - return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) - }, - removeFromRouteTable, - ) - - return setupHooks(*routeManager, initAddresses) -} - -func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil -} - -func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil + return genericRemoveVPNRoute(prefix, intf) } diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_test.go similarity index 70% rename from client/internal/routemanager/systemops_nonlinux_test.go rename to client/internal/routemanager/systemops_test.go index adb83bac6..97386f19a 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -1,13 +1,15 @@ -//go:build !linux && !ios +//go:build !android && !ios package routemanager import ( "bytes" + "context" "fmt" "net" "net/netip" "os" + "runtime" "strings" "testing" @@ -20,16 +22,9 @@ import ( "github.com/netbirdio/netbird/iface" ) -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - - prefixGateway, _, err := getNextHop(prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) } func TestAddRemoveRoutes(t *testing.T) { @@ -72,8 +67,8 @@ func TestAddRemoveRoutes(t *testing.T) { assert.NoError(t, cleanupRouting()) }) - err = addVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.shouldRouteToWireguard { assertWGOutInterface(t, testCase.prefix, wgInterface, false) @@ -83,8 +78,8 @@ func TestAddRemoveRoutes(t *testing.T) { exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "removeVPNRoute should not return err") + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericRemoveVPNRoute should not return err") prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) require.NoError(t, err, "getNextHop should not return err") @@ -144,7 +139,7 @@ func TestGetNextHop(t *testing.T) { } } -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { +func TestAddExistAndRemoveRoute(t *testing.T) { defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { @@ -205,20 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addVPNRoute(testCase.prefix, wgInterface.Name()) + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -228,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -284,12 +273,6 @@ func TestIsSubRange(t *testing.T) { } func TestExistsInRouteTable(t *testing.T) { - _, _, err := setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - addresses, err := net.InterfaceAddrs() if err != nil { t.Fatal("shouldn't return error when fetching interface addresses: ", err) @@ -298,10 +281,19 @@ func TestExistsInRouteTable(t *testing.T) { var addressPrefixes []netip.Prefix for _, address := range addresses { p := netip.MustParsePrefix(address.String()) - // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() { - addressPrefixes = append(addressPrefixes, p.Masked()) + if p.Addr().Is6() { + continue } + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { + continue + } + // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence + if runtime.GOOS == "linux" && p.Addr().IsLoopback() { + continue + } + + addressPrefixes = append(addressPrefixes, p.Masked()) } for _, prefix := range addressPrefixes { @@ -314,3 +306,97 @@ func TestExistsInRouteTable(t *testing.T) { } } } + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet() + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} + +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { + return + } + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go deleted file mode 100644 index 28a6502d2..000000000 --- a/client/internal/routemanager/sytemops_test.go +++ /dev/null @@ -1,101 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "context" - "net" - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" -) - -type dialer interface { - Dial(network, address string) (net.Conn, error) - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet(nil) - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) { - t.Helper() - - setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - _, _, err := setupRouting(nil, wgIface) - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) -} diff --git a/util/net/dialer.go b/util/net/dialer.go index 7b9bddbb5..d3adef363 100644 --- a/util/net/dialer.go +++ b/util/net/dialer.go @@ -35,7 +35,7 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { udpConn, ok := conn.(*net.UDPConn) if !ok { if err := conn.Close(); err != nil { - log.Errorf("Failed to closeConn connection: %v", err) + log.Errorf("Failed to close connection: %v", err) } return nil, fmt.Errorf("expected UDP connection, got different type") } diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index ae412415f..a195bdeb9 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -145,10 +145,19 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { // ListenUDP listens on the network address and returns a transport.UDPConn // which includes support for write and close hooks. func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { - udpConn, err := net.ListenUDP(network, laddr) + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) if err != nil { return nil, fmt.Errorf("listen UDP: %w", err) } - connID := GenerateConnID() - return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type") + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil } From 25f5f26527d777c1064d2825209da5183ea3dcae Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 18:57:50 +0200 Subject: [PATCH 15/28] Timeout rule removing loop and catch IPv6 unsupported error in loop (#1791) --- .../internal/routemanager/systemops_linux.go | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 44691f0d6..ef4643727 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -4,12 +4,14 @@ package routemanager import ( "bufio" + "context" "errors" "fmt" "net" "net/netip" "os" "syscall" + "time" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -444,7 +446,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil { return fmt.Errorf("remove routing rule: %w", err) } @@ -452,15 +454,33 @@ func removeRule(params ruleParams) error { } func removeAllRules(params ruleParams) error { - for { - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) { - break + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + for { + if ctx.Err() != nil { + done <- ctx.Err() + return + } + if err := removeRule(params); err != nil { + if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { + done <- nil + return + } + done <- err + return } - return err } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return err } - return nil } // addNextHop adds the gateway and device to the route. From 3d2a2377c685d397c876587f9a55e04057f341db Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 19:06:04 +0200 Subject: [PATCH 16/28] Don't return errors on disallowed routes (#1792) --- client/internal/routemanager/routemanager.go | 5 ++- client/internal/routemanager/systemops.go | 39 +++++++++---------- .../routemanager/systemops_linux_test.go | 4 +- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index 39b55f105..8f9ff9f4b 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -54,9 +54,12 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref if ref.count == 0 { log.Debugf("Adding route for prefix %s", prefix) nexthop, intf, err := rm.addRoute(prefix) - if errors.Is(err, errRouteNotFound) { + if errors.Is(err, ErrRouteNotFound) { return nil } + if errors.Is(err, ErrRouteNotAllowed) { + log.Debugf("Adding route for prefix %s: %s", prefix, err) + } if err != nil { return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index c6f3376e0..a91f53636 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -23,7 +23,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) -var errRouteNotFound = fmt.Errorf("route not found") +var ErrRouteNotFound = errors.New("route not found") +var ErrRouteNotAllowed = errors.New("route not allowed") // TODO: fix: for default our wg address now appears as the default gw func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { @@ -33,7 +34,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { } defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("get existing route gateway: %s", err) } @@ -59,7 +60,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { var exitIntf string gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } if intf != nil { @@ -78,13 +79,13 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) if err != nil { log.Warnf("Failed to get route for %s: %v", ip, err) - return netip.Addr{}, nil, errRouteNotFound + return netip.Addr{}, nil, ErrRouteNotFound } log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { if preferredSrc == nil { - return netip.Addr{}, nil, errRouteNotFound + return netip.Addr{}, nil, ErrRouteNotFound } log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) @@ -129,8 +130,8 @@ func isSubRange(prefix netip.Prefix) (bool, error) { return false, nil } -// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. -// If the next hop or interface is pointing to the VPN interface, it will return an error +// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. +// If the next hop or interface is pointing to the VPN interface, it will return the initial values. func addRouteToNonVPNIntf( prefix netip.Prefix, vpnIntf *iface.WGIface, @@ -139,18 +140,14 @@ func addRouteToNonVPNIntf( ) (netip.Addr, string, error) { addr := prefix.Addr() switch { - case addr.IsLoopback(): - return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) - case addr.IsLinkLocalUnicast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) - case addr.IsLinkLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) - case addr.IsInterfaceLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) - case addr.IsUnspecified(): - return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) - case addr.IsMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + case addr.IsLoopback(), + addr.IsLinkLocalUnicast(), + addr.IsLinkLocalMulticast(), + addr.IsInterfaceLocalMulticast(), + addr.IsUnspecified(), + addr.IsMulticast(): + + return netip.Addr{}, "", ErrRouteNotAllowed } // Determine the exit interface and next hop for the prefix, so we can add a specific route @@ -316,11 +313,11 @@ func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) } initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { log.Errorf("Unable to get initial v6 default next hop: %v", err) } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index d77c7cc7d..0043c3f4e 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -138,7 +138,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { if dstIPNet.String() == "0.0.0.0/0" { var err error originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { t.Logf("Failed to fetch original gateway: %v", err) } @@ -193,7 +193,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } } - return nil, 0, errRouteNotFound + return nil, 0, ErrRouteNotFound } func setupDummyInterfacesAndRoutes(t *testing.T) { From 3461b1bb90e71152084b57e4feec1583a3cf60b9 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 5 Apr 2024 00:10:32 +0200 Subject: [PATCH 17/28] Expect correct conn type (#1801) --- util/net/dialer.go | 43 ------------------------------------ util/net/dialer_generic.go | 40 +++++++++++++++++++++++++++++++++ util/net/dialer_mobile.go | 15 +++++++++++++ util/net/listener_generic.go | 2 +- 4 files changed, 56 insertions(+), 44 deletions(-) create mode 100644 util/net/dialer_mobile.go diff --git a/util/net/dialer.go b/util/net/dialer.go index d3adef363..0786c667e 100644 --- a/util/net/dialer.go +++ b/util/net/dialer.go @@ -1,10 +1,7 @@ package net import ( - "fmt" "net" - - log "github.com/sirupsen/logrus" ) // Dialer extends the standard net.Dialer with the ability to execute hooks before @@ -22,43 +19,3 @@ func NewDialer() *Dialer { return dialer } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type") - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type") - } - - return tcpConn, nil -} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 2e102da50..06fac3bbf 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -121,3 +121,43 @@ func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, return result.ErrorOrNil() } + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_mobile.go b/util/net/dialer_mobile.go new file mode 100644 index 000000000..b95aaa973 --- /dev/null +++ b/util/net/dialer_mobile.go @@ -0,0 +1,15 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + return net.DialUDP(network, laddr, raddr) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + return net.DialTCP(network, laddr, raddr) +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index a195bdeb9..451279e9d 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -156,7 +156,7 @@ func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { if err := packetConn.Close(); err != nil { log.Errorf("Failed to close connection: %v", err) } - return nil, fmt.Errorf("expected UDPConn, got different type") + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) } return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil From 1d1d057e7d01f8e3e34bb509738a7dbb5f21a62e Mon Sep 17 00:00:00 2001 From: trax <38944599+4nx@users.noreply.github.com> Date: Fri, 5 Apr 2024 13:51:28 +0200 Subject: [PATCH 18/28] Change the dashboard image pull from wiretrustee to netbirdio (#1804) --- infrastructure_files/docker-compose.yml.tmpl.traefik | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index fd194a042..d3ae6529a 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -2,7 +2,7 @@ version: "3" services: #UI dashboard dashboard: - image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG + image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG restart: unless-stopped #ports: # - 80:80 From 9f32ccd4533d5301bcb901677af9816cb3408f92 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 5 Apr 2024 20:38:49 +0200 Subject: [PATCH 19/28] Rollback new routing functionality (#1805) --- .github/workflows/golang-test-darwin.yml | 3 - .github/workflows/golang-test-linux.yml | 10 +- .github/workflows/golang-test-windows.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- client/internal/engine.go | 16 - client/internal/peer/conn.go | 32 -- client/internal/relay/relay.go | 7 +- client/internal/routemanager/client.go | 63 +-- client/internal/routemanager/manager.go | 79 +-- client/internal/routemanager/manager_test.go | 30 +- client/internal/routemanager/mock.go | 5 - client/internal/routemanager/routemanager.go | 126 ----- .../routemanager/server_nonandroid.go | 57 +- client/internal/routemanager/systemops.go | 407 -------------- .../routemanager/systemops_android.go | 24 +- client/internal/routemanager/systemops_bsd.go | 1 + .../internal/routemanager/systemops_darwin.go | 61 -- .../routemanager/systemops_darwin_test.go | 100 ---- client/internal/routemanager/systemops_ios.go | 26 +- .../internal/routemanager/systemops_linux.go | 531 +++--------------- .../routemanager/systemops_linux_test.go | 207 ------- .../routemanager/systemops_nonandroid.go | 120 ++++ ...s_test.go => systemops_nonandroid_test.go} | 212 ++----- .../routemanager/systemops_nonlinux.go | 32 +- .../routemanager/systemops_unix_test.go | 234 -------- .../routemanager/systemops_windows.go | 88 +-- .../routemanager/systemops_windows_test.go | 289 ---------- client/internal/stdnet/dialer.go | 24 - client/internal/stdnet/listener.go | 20 - client/internal/wgproxy/proxy_ebpf.go | 9 +- client/internal/wgproxy/proxy_userspace.go | 4 +- go.mod | 2 +- iface/wg_configurer_kernel.go | 4 +- iface/wg_configurer_usp.go | 11 +- management/client/grpc.go | 2 - sharedsock/sock_linux.go | 10 - signal/client/grpc.go | 2 - util/grpc/dialer.go | 22 - util/net/dialer.go | 21 - util/net/dialer_generic.go | 163 ------ util/net/dialer_linux.go | 12 - util/net/dialer_nonlinux.go | 6 - util/net/listener.go | 21 - util/net/listener_generic.go | 163 ------ util/net/listener_linux.go | 14 - util/net/listener_mobile.go | 11 - util/net/listener_nonlinux.go | 6 - util/net/net.go | 17 - util/net/net_linux.go | 35 -- 49 files changed, 364 insertions(+), 2979 deletions(-) delete mode 100644 client/internal/routemanager/routemanager.go delete mode 100644 client/internal/routemanager/systemops.go delete mode 100644 client/internal/routemanager/systemops_darwin.go delete mode 100644 client/internal/routemanager/systemops_darwin_test.go delete mode 100644 client/internal/routemanager/systemops_linux_test.go create mode 100644 client/internal/routemanager/systemops_nonandroid.go rename client/internal/routemanager/{systemops_test.go => systemops_nonandroid_test.go} (59%) delete mode 100644 client/internal/routemanager/systemops_unix_test.go delete mode 100644 client/internal/routemanager/systemops_windows_test.go delete mode 100644 client/internal/stdnet/dialer.go delete mode 100644 client/internal/stdnet/listener.go delete mode 100644 util/grpc/dialer.go delete mode 100644 util/net/dialer.go delete mode 100644 util/net/dialer_generic.go delete mode 100644 util/net/dialer_linux.go delete mode 100644 util/net/dialer_nonlinux.go delete mode 100644 util/net/listener.go delete mode 100644 util/net/listener_generic.go delete mode 100644 util/net/listener_linux.go delete mode 100644 util/net/listener_mobile.go delete mode 100644 util/net/listener_nonlinux.go delete mode 100644 util/net/net.go delete mode 100644 util/net/net_linux.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index d7007c860..f8afd3d6e 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,9 +32,6 @@ jobs: restore-keys: | macos-go- - - name: Install libpcap - run: brew install libpcap - - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 42f740e9b..74e6d1203 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -36,11 +36,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib - name: Install modules run: go mod tidy @@ -71,7 +67,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib - name: Install modules run: go mod tidy @@ -86,7 +82,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... + run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2d63acbcd..6027d3626 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 13228250d..9f543c74c 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -40,7 +40,7 @@ jobs: cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index d6238c4b3..13ef8ce15 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -94,9 +94,6 @@ type Engine struct { // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn - beforePeerHook peer.BeforeAddPeerHookFunc - afterPeerHook peer.AfterRemovePeerHookFunc - // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -264,14 +261,6 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() - if err != nil { - log.Errorf("Failed to initialize route manager: %s", err) - } else { - e.beforePeerHook = beforePeerHook - e.afterPeerHook = afterPeerHook - } - e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -821,11 +810,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { } e.peerConns[peerKey] = conn - if e.beforePeerHook != nil && e.afterPeerHook != nil { - conn.AddBeforeAddPeerHook(e.beforePeerHook) - conn.AddAfterRemovePeerHook(e.afterPeerHook) - } - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index f3d07dcad..17ef7e87f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,7 +20,6 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" - nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -101,9 +100,6 @@ type IceCredentials struct { Pwd string } -type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error -type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error - type Conn struct { config ConnConfig mu sync.Mutex @@ -142,10 +138,6 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn - - connID nbnet.ConnectionID - beforeAddPeerHooks []BeforeAddPeerHookFunc - afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -401,14 +393,6 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } -func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { - conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) -} - -func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { - conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) -} - // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -437,13 +421,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem conn.remoteEndpoint = endpointUdpAddr log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) - conn.connID = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { - log.Errorf("Before add peer hook failed: %v", err) - } - } - err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { if conn.wgProxy != nil { @@ -534,15 +511,6 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if conn.connID != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connID); err != nil { - log.Errorf("After remove peer hook failed: %v", err) - } - } - } - conn.connID = "" - if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 84fd72e49..ad3b94f2a 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,7 +12,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request @@ -96,13 +95,15 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) switch uri.Proto { case stun.ProtoTypeUDP: var err error - conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "") + listener := &net.ListenConfig{} + conn, err = listener.ListenPacket(ctx, "udp", "") if err != nil { probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: - tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr) + dialer := &net.Dialer{} + tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 38cf4bf65..f7ead5827 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -41,7 +41,6 @@ type clientNetwork struct { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { ctx, cancel := context.WithCancel(ctx) - client := &clientNetwork{ ctx: ctx, stop: cancel, @@ -73,18 +72,6 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { return routePeerStatuses } -// getBestRouteFromStatuses determines the most optimal route from the available routes -// within a clientNetwork, taking into account peer connection status, route metrics, and -// preference for non-relayed and direct connections. -// -// It follows these prioritization rules: -// * Connected peers: Only routes with connected peers are considered. -// * Metric: Routes with lower metrics (better) are prioritized. -// * Non-relayed: Routes without relays are preferred. -// * Direct connections: Routes with direct peer connections are favored. -// * Stability: In case of equal scores, the currently active route (if any) is maintained. -// -// It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" chosenScore := 0 @@ -171,7 +158,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) if err != nil { - return fmt.Errorf("get peer state: %v", err) + return err } delete(state.Routes, c.network.String()) @@ -185,7 +172,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) if err != nil { - return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", + return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } return nil @@ -193,26 +180,30 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { - return fmt.Errorf("remove route %s from system, err: %v", c.network, err) + err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err } - - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route: %v", err) + err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) + if err != nil { + return fmt.Errorf("couldn't remove route %s from system, err: %v", + c.network, err) } } return nil } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { + + var err error + routerPeerStatuses := c.getRouterPeerStatuses() chosen := c.getBestRouteFromStatuses(routerPeerStatuses) - - // If no route is chosen, remove the route from the peer and system if chosen == "" { - if err := c.removeRouteFromPeerAndSystem(); err != nil { - return fmt.Errorf("remove route from peer and system: %v", err) + err = c.removeRouteFromPeerAndSystem() + if err != nil { + return err } c.chosenRoute = nil @@ -220,7 +211,6 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } - // If the chosen route is the same as the current route, do nothing if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute.IsEqual(c.routes[chosen]) { return nil @@ -228,13 +218,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - // If a previous route exists, remove it from the peer - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route from peer: %v", err) + err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err } } else { - // otherwise add the route to the system - if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { + err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) + if err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -255,7 +245,8 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } - if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { + err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) + if err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } @@ -296,21 +287,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { - log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) + log.Error(err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system: %v", err) + log.Error(err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("Received a routes update with smaller serial number, ignoring it") + log.Warnf("received a routes update with smaller serial number, ignoring it") continue } - log.Debugf("Received a new client network route update for %s", c.network) + log.Debugf("received a new client network route update for %s", c.network) c.handleUpdate(update) @@ -318,7 +309,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) + log.Error(err) } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 36a37f02c..b624d8c34 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,10 +2,6 @@ package routemanager import ( "context" - "fmt" - "net" - "net/netip" - "net/url" "runtime" "sync" @@ -19,14 +15,8 @@ import ( "github.com/netbirdio/netbird/version" ) -var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - -// nolint:unused -var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - // Manager is a route manager interface type Manager interface { - Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -66,24 +56,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } -// Init sets up the routing -func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - if err := cleanupRouting(); err != nil { - log.Warnf("Failed cleaning up routing: %v", err) - } - - mgmtAddress := m.statusRecorder.GetManagementState().URL - signalAddress := m.statusRecorder.GetSignalState().URL - ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) - - beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) - if err != nil { - return nil, nil, fmt.Errorf("setup routing: %w", err) - } - log.Info("Routing setup complete") - return beforePeerHook, afterPeerHook, nil -} - func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -99,15 +71,9 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } - if err := cleanupRouting(); err != nil { - log.Errorf("Error cleaning up routing: %v", err) - } else { - log.Info("Routing cleanup complete") - } - m.ctx = nil } -// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps +// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -125,7 +91,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return fmt.Errorf("update routes: %w", err) + return err } } @@ -190,7 +156,11 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] for _, newRoute := range newRoutes { networkID := route.GetHAUniqueID(newRoute) if !ownNetworkIDs[networkID] { - if !isPrefixSupported(newRoute.Network) { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < minRangeBits { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", + version.NetbirdVersion(), newRoute.Network) continue } newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) @@ -208,38 +178,3 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } - -func isPrefixSupported(prefix netip.Prefix) bool { - switch runtime.GOOS { - case "linux", "windows", "darwin": - return true - } - - // If prefix is too small, lets assume it is a possible default prefix which is not yet supported - // we skip this prefix management - if prefix.Bits() <= minRangeBits { - log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", - version.NetbirdVersion(), prefix) - return false - } - return true -} - -// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. -func resolveURLsToIPs(urls []string) []net.IP { - var ips []net.IP - for _, rawurl := range urls { - u, err := url.Parse(rawurl) - if err != nil { - log.Errorf("Failed to parse url %s: %v", rawurl, err) - continue - } - ipAddrs, err := net.LookupIP(u.Hostname()) - if err != nil { - log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) - continue - } - ips = append(ips, ipAddrs...) - } - return ips -} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 03e77e09b..2e5cf6649 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,14 +28,13 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int - clientNetworkWatchersExpectedAllowed int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int }{ { name: "Should create 2 client networks", @@ -201,9 +200,8 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, - clientNetworkWatchersExpectedAllowed: 1, + inputSerial: 1, + clientNetworkWatchersExpected: 0, }, { name: "Remove 1 Client Route", @@ -417,10 +415,6 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) - - _, _, err = routeManager.Init() - - require.NoError(t, err, "should init route manager") defer routeManager.Stop() if testCase.removeSrvRouter { @@ -435,11 +429,7 @@ func TestManagerUpdateRoutes(t *testing.T) { err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") - expectedWatchers := testCase.clientNetworkWatchersExpected - if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { - expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed - } - require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") + require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index dd2c28e59..a1214cbb9 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,7 +6,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -17,10 +16,6 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - // InitialRouteRange mock implementation of InitialRouteRange from Manager interface func (m *MockManager) InitialRouteRange() []string { return nil diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go deleted file mode 100644 index 8f9ff9f4b..000000000 --- a/client/internal/routemanager/routemanager.go +++ /dev/null @@ -1,126 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "errors" - "fmt" - "net/netip" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -type ref struct { - count int - nexthop netip.Addr - intf string -} - -type RouteManager struct { - // refCountMap keeps track of the reference ref for prefixes - refCountMap map[netip.Prefix]ref - // prefixMap keeps track of the prefixes associated with a connection ID for removal - prefixMap map[nbnet.ConnectionID][]netip.Prefix - addRoute AddRouteFunc - removeRoute RemoveRouteFunc - mutex sync.Mutex -} - -type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) -type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error - -func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { - // TODO: read initial routing table into refCountMap - return &RouteManager{ - refCountMap: map[netip.Prefix]ref{}, - prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, - addRoute: addRoute, - removeRoute: removeRoute, - } -} - -func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - ref := rm.refCountMap[prefix] - log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) - - // Add route to the system, only if it's a new prefix - if ref.count == 0 { - log.Debugf("Adding route for prefix %s", prefix) - nexthop, intf, err := rm.addRoute(prefix) - if errors.Is(err, ErrRouteNotFound) { - return nil - } - if errors.Is(err, ErrRouteNotAllowed) { - log.Debugf("Adding route for prefix %s: %s", prefix, err) - } - if err != nil { - return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) - } - ref.nexthop = nexthop - ref.intf = intf - } - - ref.count++ - rm.refCountMap[prefix] = ref - rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) - - return nil -} - -func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - prefixes, ok := rm.prefixMap[connID] - if !ok { - log.Debugf("No prefixes found for connection ID %s", connID) - return nil - } - - var result *multierror.Error - for _, prefix := range prefixes { - ref := rm.refCountMap[prefix] - log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) - if ref.count == 1 { - log.Debugf("Removing route for prefix %s", prefix) - // TODO: don't fail if the route is not found - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - continue - } - delete(rm.refCountMap, prefix) - } else { - ref.count-- - rm.refCountMap[prefix] = ref - } - } - delete(rm.prefixMap, connID) - - return result.ErrorOrNil() -} - -// Flush removes all references and routes from the system -func (rm *RouteManager) Flush() error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - var result *multierror.Error - for prefix := range rm.refCountMap { - log.Debugf("Removing route for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - } - } - rm.refCountMap = map[netip.Prefix]ref{} - rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} - - return result.ErrorOrNil() -} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index af82dc913..192367877 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,7 +4,6 @@ package routemanager import ( "context" - "fmt" "net/netip" "sync" @@ -49,7 +48,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er oldRoute := m.routes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } delete(m.routes, routeID) @@ -63,7 +62,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er err := m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) continue } m.routes[id] = newRoute @@ -82,22 +81,15 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("Not removing from server network because context is done") + log.Infof("not removing from server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) + err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { - return fmt.Errorf("parse prefix: %w", err) + return err } - - err = m.firewall.RemoveRoutingRules(routerPair) - if err != nil { - return fmt.Errorf("remove routing rules: %w", err) - } - delete(m.routes, route.ID) state := m.statusRecorder.GetLocalPeerState() @@ -111,22 +103,15 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("Not adding to server network because context is done") + log.Infof("not adding to server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) + err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { - return fmt.Errorf("parse prefix: %w", err) + return err } - - err = m.firewall.InsertRoutingRules(routerPair) - if err != nil { - return fmt.Errorf("insert routing rules: %w", err) - } - m.routes[route.ID] = route state := m.statusRecorder.GetLocalPeerState() @@ -144,33 +129,23 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) + err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) if err != nil { - log.Errorf("Failed to convert route to router pair: %v", err) - continue - } - - err = m.firewall.RemoveRoutingRules(routerPair) - if err != nil { - log.Errorf("Failed to remove cleanup route: %v", err) + log.Warnf("failed to remove clean up route: %s", r.ID) } + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } - - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } -func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { - parsed, err := netip.ParsePrefix(source) - if err != nil { - return firewall.RouterPair{}, err - } +func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { + parsed := netip.MustParsePrefix(source).Masked() return firewall.RouterPair{ ID: route.ID, Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, - }, nil + } } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go deleted file mode 100644 index a91f53636..000000000 --- a/client/internal/routemanager/systemops.go +++ /dev/null @@ -1,407 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "context" - "errors" - "fmt" - "net" - "net/netip" - - "github.com/hashicorp/go-multierror" - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" -) - -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - -var ErrRouteNotFound = errors.New("route not found") -var ErrRouteNotAllowed = errors.New("route not allowed") - -// TODO: fix: for default our wg address now appears as the default gw -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(defaultGateway) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) - if defaultGateway.Is6() { - gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - var exitIntf string - gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - if intf != nil { - exitIntf = intf.Name - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) -} - -func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Warnf("Failed to get route for %s: %v", ip, err) - return netip.Addr{}, nil, ErrRouteNotFound - } - - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - if preferredSrc == nil { - return netip.Addr{}, nil, ErrRouteNotFound - } - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) - } - return addr.Unmap(), intf, nil - } - - addr, ok := netip.AddrFromSlice(gateway) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) - } - - return addr.Unmap(), intf, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. -// If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(), - addr.IsLinkLocalUnicast(), - addr.IsLinkLocalMulticast(), - addr.IsInterfaceLocalMulticast(), - addr.IsUnspecified(), - addr.IsMulticast(): - - return netip.Addr{}, "", ErrRouteNotAllowed - } - - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) - if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) - } - - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } - - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") - } - - // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) - exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } - } - - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) - } - - return exitNextHop, exitIntf, nil -} - -// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix -// in two /1 prefixes to avoid replacing the existing default route -func genericAddVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - return err - } - if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } - - // TODO: remove once IPv6 is supported on the interface - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } else if prefix == defaultv6 { - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } - - return addNonExistingRoute(prefix, intf) -} - -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, netip.Addr{}, intf) -} - -// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, -// it will remove the split /1 prefixes -func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - // TODO: remove once IPv6 is supported on the interface - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } else if prefix == defaultv6 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } - - return removeFromRouteTable(prefix, netip.Addr{}, intf) -} - -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v4 default next hop: %v", err) - } - initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v6 default next hop: %v", err) - } - - *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { - addr := prefix.Addr() - nexthop, intf := initialNextHopV4, initialIntfV4 - if addr.Is6() { - nexthop, intf = initialNextHopV6, initialIntfV6 - } - return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) - }, - removeFromRouteTable, - ) - - return setupHooks(*routeManager, initAddresses) -} - -func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil -} - -func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil -} diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 34d2d270f..950a26843 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,33 +1,13 @@ package routemanager import ( - "net" "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { return nil } -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, string) error { - return nil -} - -func removeVPNRoute(netip.Prefix, string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 173e7c0e8..b2da8075c 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -1,4 +1,5 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd +// +build darwin dragonfly freebsd netbsd openbsd package routemanager diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go deleted file mode 100644 index f34964a83..000000000 --- a/client/internal/routemanager/systemops_darwin.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build darwin && !ios - -package routemanager - -import ( - "fmt" - "net" - "net/netip" - "os/exec" - "strings" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" -) - -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) -} - -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return routeCmd("add", prefix, nexthop, intf) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return routeCmd("delete", prefix, nexthop, intf) -} - -func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { - inet := "-inet" - if prefix.Addr().Is6() { - inet = "-inet6" - // Special case for IPv6 split default route, pointing to the wg interface fails - // TODO: Remove once we have IPv6 support on the interface - if prefix.Bits() == 1 { - intf = "lo0" - } - } - - args := []string{"-n", action, inet, prefix.String()} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) - } else if intf != "" { - args = append(args, "-interface", intf) - } - - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) - - if err != nil { - return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) - } - return nil -} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go deleted file mode 100644 index 5c5aaa24f..000000000 --- a/client/internal/routemanager/systemops_darwin_test.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build !ios - -package routemanager - -import ( - "fmt" - "net" - "os/exec" - "regexp" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var expectedVPNint = "utun100" -var expectedExternalInt = "lo0" -var expectedInternalInt = "lo0" - -func init() { - testCases = append(testCases, []testCase{ - { - name: "To more specific route without custom dialer via vpn", - destination: "10.10.0.2:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), - }, - }...) -} - -func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { - t.Helper() - - err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() - require.NoError(t, err, "Failed to create loopback alias") - - t.Cleanup(func() { - err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() - assert.NoError(t, err, "Failed to remove loopback alias") - }) - - return "lo0" -} - -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { - t.Helper() - - var originalNexthop net.IP - if dstCIDR == "0.0.0.0/0" { - var err error - originalNexthop, err = fetchOriginalGateway() - if err != nil { - t.Logf("Failed to fetch original gateway: %v", err) - } - - if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { - t.Logf("Failed to delete route: %v, output: %s", err, output) - } - } - - t.Cleanup(func() { - if originalNexthop != nil { - err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() - assert.NoError(t, err, "Failed to restore original route") - } - }) - - err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() - require.NoError(t, err, "Failed to add route") - - t.Cleanup(func() { - err := exec.Command("route", "delete", "-net", dstCIDR).Run() - assert.NoError(t, err, "Failed to remove route") - }) -} - -func fetchOriginalGateway() (net.IP, error) { - output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() - if err != nil { - return nil, err - } - - matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) - if len(matches) == 0 { - return nil, fmt.Errorf("gateway not found") - } - - return net.ParseIP(matches[1]), nil -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) - - otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) -} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 34d2d270f..aae0f8dc8 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,33 +1,15 @@ +//go:build ios + package routemanager import ( - "net" "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { return nil } -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, string) error { - return nil -} - -func removeVPNRoute(netip.Prefix, string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index ef4643727..0562826a5 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -3,342 +3,142 @@ package routemanager import ( - "bufio" - "context" - "errors" - "fmt" "net" "net/netip" "os" "syscall" - "time" + "unsafe" - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -const ( - // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. - NetbirdVPNTableID = 0x1BD0 - // NetbirdVPNTableName is the name of the custom routing table used by Netbird. - NetbirdVPNTableName = "netbird" +// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html +// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. +type routeInfoInMemory struct { + Family byte + DstLen byte + SrcLen byte + TOS byte - // rtTablesPath is the path to the file containing the routing table names. - rtTablesPath = "/etc/iproute2/rt_tables" + Table byte + Protocol byte + Scope byte + Type byte - // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. - ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" -) - -var ErrTableIDExists = errors.New("ID exists with different name") - -var routeManager = &RouteManager{} -var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" - -type ruleParams struct { - fwmark int - tableID int - family int - priority int - invert bool - suppressPrefix int - description string + Flags uint32 } -func getSetupRules() []ruleParams { - return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, - } -} +const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" -// setupRouting establishes the routing configuration for the VPN, including essential rules -// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. -// -// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over -// potential routes received and configured for the VPN. This rule is skipped for the default route and routes -// that are not in the main table. -// -// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. -// This table is where a default route or other specific routes received from the management server are configured, -// enabling VPN connectivity. -// -// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { - if isLegacy { - log.Infof("Using legacy routing setup") - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +func addToRouteTable(prefix netip.Prefix, addr string) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err } - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" } - defer func() { - if err != nil { - if cleanErr := cleanupRouting(); cleanErr != nil { - log.Errorf("Error cleaning up routing: %v", cleanErr) - } - } - }() - - rules := getSetupRules() - for _, rule := range rules { - if err := addRule(rule); err != nil { - if errors.Is(err, syscall.EOPNOTSUPP) { - log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") - isLegacy = true - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) - } - return nil, nil, fmt.Errorf("%s: %w", rule.description, err) - } + ip, _, err := net.ParseCIDR(addr + addrMask) + if err != nil { + return err } - return nil, nil, nil -} - -// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. -// It systematically removes the three rules and any associated routing table entries to ensure a clean state. -// The function uses error aggregation to report any errors encountered during the cleanup process. -func cleanupRouting() error { - if isLegacy { - return cleanupRoutingWithRouteManager(routeManager) + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, } - var result *multierror.Error - - if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { - result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err)) - } - if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { - result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err)) + err = netlink.RouteAdd(route) + if err != nil { + return err } - rules := getSetupRules() - for _, rule := range rules { - if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { - result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) - } - } - - return result.ErrorOrNil() -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) -} - -func addVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { - return genericAddVPNRoute(prefix, intf) - } - - // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 - - // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add blackhole: %w", err) - } - } - if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add route: %w", err) - } return nil } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { - return genericRemoveVPNRoute(prefix, intf) +func removeFromRouteTable(prefix netip.Prefix, addr string) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err } - // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove unreachable route: %w", err) - } + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" } - if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove route: %w", err) + + ip, _, err := net.ParseCIDR(addr + addrMask) + if err != nil { + return err } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, + } + + err = netlink.RouteDel(route) + if err != nil { + return err + } + return nil } func getRoutesFromTable() ([]netip.Prefix, error) { - v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) + tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) if err != nil { - return nil, fmt.Errorf("get v4 routes: %w", err) + return nil, err } - v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) + msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { - return nil, fmt.Errorf("get v6 routes: %w", err) - + return nil, err } - return append(v4Routes, v6Routes...), nil -} - -// getRoutes fetches routes from a specific routing table identified by tableID. -func getRoutes(tableID, family int) ([]netip.Prefix, error) { var prefixList []netip.Prefix - - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) - } - - for _, route := range routes { - if route.Dst != nil { - addr, ok := netip.AddrFromSlice(route.Dst.IP) - if !ok { - return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) +loop: + for _, m := range msgs { + switch m.Header.Type { + case syscall.NLMSG_DONE: + break loop + case syscall.RTM_NEWROUTE: + rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) + msg := m + attrs, err := syscall.ParseNetlinkRouteAttr(&msg) + if err != nil { + return nil, err + } + if rt.Family != syscall.AF_INET { + continue loop } - ones, _ := route.Dst.Mask.Size() - - prefix := netip.PrefixFrom(addr, ones) - if prefix.IsValid() { - prefixList = append(prefixList, prefix) + for _, attr := range attrs { + if attr.Attr.Type == syscall.RTA_DST { + addr, ok := netip.AddrFromSlice(attr.Value) + if !ok { + continue + } + mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) + cidr, _ := mask.Size() + routePrefix := netip.PrefixFrom(addr, cidr) + if routePrefix.IsValid() && routePrefix.Addr().Is4() { + prefixList = append(prefixList, routePrefix) + } + } } } } - return prefixList, nil } -// addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: getAddressFamily(prefix), - } - - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - route.Dst = ipNet - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink add route: %w", err) - } - - return nil -} - -// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. -// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. -// tableID specifies the routing table to which the unreachable route will be added. -func addUnreachableRoute(prefix netip.Prefix, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, - } - - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink add unreachable route: %w", err) - } - - return nil -} - -func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, - } - - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink remove unreachable route: %w", err) - } - - return nil - -} - -// removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, - } - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink remove route: %w", err) - } - - return nil -} - -func flushRoutes(tableID, family int) error { - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return fmt.Errorf("list routes from table %d: %w", tableID, err) - } - - var result *multierror.Error - for i := range routes { - route := routes[i] - // unreachable default routes don't come back with Dst set - if route.Gw == nil && route.Src == nil && route.Dst == nil { - if family == netlink.FAMILY_V4 { - routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} - } else { - routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} - } - } - if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) - } - } - - return result.ErrorOrNil() -} - func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { - return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) + return err } // check if it is already enabled @@ -347,162 +147,5 @@ func enableIPForwarding() error { return nil } - //nolint:gosec - if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { - return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) - } - return nil -} - -// entryExists checks if the specified ID or name already exists in the rt_tables file -// and verifies if existing names start with "netbird_". -func entryExists(file *os.File, id int) (bool, error) { - if _, err := file.Seek(0, 0); err != nil { - return false, fmt.Errorf("seek rt_tables: %w", err) - } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - var existingID int - var existingName string - if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil { - if existingID == id { - if existingName != NetbirdVPNTableName { - return true, ErrTableIDExists - } - return true, nil - } - } - } - if err := scanner.Err(); err != nil { - return false, fmt.Errorf("scan rt_tables: %w", err) - } - return false, nil -} - -// addRoutingTableName adds human-readable names for custom routing tables. -func addRoutingTableName() error { - file, err := os.Open(rtTablesPath) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return fmt.Errorf("open rt_tables: %w", err) - } - defer func() { - if err := file.Close(); err != nil { - log.Errorf("Error closing rt_tables: %v", err) - } - }() - - exists, err := entryExists(file, NetbirdVPNTableID) - if err != nil { - return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err) - } - if exists { - return nil - } - - // Reopen the file in append mode to add new entries - if err := file.Close(); err != nil { - log.Errorf("Error closing rt_tables before appending: %v", err) - } - file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) - if err != nil { - return fmt.Errorf("open rt_tables for appending: %w", err) - } - - if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil { - return fmt.Errorf("append entry to rt_tables: %w", err) - } - - return nil -} - -// addRule adds a routing rule to a specific routing table identified by tableID. -func addRule(params ruleParams) error { - rule := netlink.NewRule() - rule.Table = params.tableID - rule.Mark = params.fwmark - rule.Family = params.family - rule.Priority = params.priority - rule.Invert = params.invert - rule.SuppressPrefixlen = params.suppressPrefix - - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("add routing rule: %w", err) - } - - return nil -} - -// removeRule removes a routing rule from a specific routing table identified by tableID. -func removeRule(params ruleParams) error { - rule := netlink.NewRule() - rule.Table = params.tableID - rule.Mark = params.fwmark - rule.Family = params.family - rule.Invert = params.invert - rule.Priority = params.priority - rule.SuppressPrefixlen = params.suppressPrefix - - if err := netlink.RuleDel(rule); err != nil { - return fmt.Errorf("remove routing rule: %w", err) - } - - return nil -} - -func removeAllRules(params ruleParams) error { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - done := make(chan error, 1) - go func() { - for { - if ctx.Err() != nil { - done <- ctx.Err() - return - } - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { - done <- nil - return - } - done <- err - return - } - } - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - return err - } -} - -// addNextHop adds the gateway and device to the route. -func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { - if addr.IsValid() { - route.Gw = addr.AsSlice() - } - - if intf != "" { - link, err := netlink.LinkByName(intf) - if err != nil { - return fmt.Errorf("set interface %s: %w", intf, err) - } - route.LinkIndex = link.Attrs().Index - } - - return nil -} - -func getAddressFamily(prefix netip.Prefix) int { - if prefix.Addr().Is4() { - return netlink.FAMILY_V4 - } - return netlink.FAMILY_V6 + return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go deleted file mode 100644 index 0043c3f4e..000000000 --- a/client/internal/routemanager/systemops_linux_test.go +++ /dev/null @@ -1,207 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "errors" - "fmt" - "net" - "os" - "strings" - "syscall" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" -) - -var expectedVPNint = "wgtest0" -var expectedLoopbackInt = "lo" -var expectedExternalInt = "dummyext0" -var expectedInternalInt = "dummyint0" - -func init() { - testCases = append(testCases, []testCase{ - { - name: "To more specific route without custom dialer via physical interface", - destination: "10.10.0.2:53", - expectedInterface: expectedInternalInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), - }, - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.1:53", - expectedInterface: expectedLoopbackInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, - }...) -} - -func TestEntryExists(t *testing.T) { - tempDir := t.TempDir() - tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) - - content := []string{ - "1000 reserved", - fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName), - "9999 other_table", - } - require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644)) - - file, err := os.Open(tempFilePath) - require.NoError(t, err) - defer func() { - assert.NoError(t, file.Close()) - }() - - tests := []struct { - name string - id int - shouldExist bool - err error - }{ - { - name: "ExistsWithNetbirdPrefix", - id: 7120, - shouldExist: true, - err: nil, - }, - { - name: "ExistsWithDifferentName", - id: 1000, - shouldExist: true, - err: ErrTableIDExists, - }, - { - name: "DoesNotExist", - id: 1234, - shouldExist: false, - err: nil, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - exists, err := entryExists(file, tc.id) - if tc.err != nil { - assert.ErrorIs(t, err, tc.err) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.shouldExist, exists) - }) - } -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} - err := netlink.LinkDel(dummy) - if err != nil && !errors.Is(err, syscall.EINVAL) { - t.Logf("Failed to delete dummy interface: %v", err) - } - - err = netlink.LinkAdd(dummy) - require.NoError(t, err) - - err = netlink.LinkSetUp(dummy) - require.NoError(t, err) - - if ipAddressCIDR != "" { - addr, err := netlink.ParseAddr(ipAddressCIDR) - require.NoError(t, err) - err = netlink.AddrAdd(dummy, addr) - require.NoError(t, err) - } - - t.Cleanup(func() { - err := netlink.LinkDel(dummy) - assert.NoError(t, err) - }) - - return dummy.Name -} - -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { - t.Helper() - - _, dstIPNet, err := net.ParseCIDR(dstCIDR) - require.NoError(t, err) - - // Handle existing routes with metric 0 - var originalNexthop net.IP - var originalLinkIndex int - if dstIPNet.String() == "0.0.0.0/0" { - var err error - originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - t.Logf("Failed to fetch original gateway: %v", err) - } - - if originalNexthop != nil { - err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - switch { - case err != nil && !errors.Is(err, syscall.ESRCH): - t.Logf("Failed to delete route: %v", err) - case err == nil: - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - default: - t.Logf("Failed to delete route: %v", err) - } - } - } - - link, err := netlink.LinkByName(intf) - require.NoError(t, err) - linkIndex := link.Attrs().Index - - route := &netlink.Route{ - Dst: dstIPNet, - Gw: gw, - LinkIndex: linkIndex, - } - err = netlink.RouteDel(route) - if err != nil && !errors.Is(err, syscall.ESRCH) { - t.Logf("Failed to delete route: %v", err) - } - - err = netlink.RouteAdd(route) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - require.NoError(t, err) -} - -func fetchOriginalGateway(family int) (net.IP, int, error) { - routes, err := netlink.RouteList(nil, family) - if err != nil { - return nil, 0, err - } - - for _, route := range routes { - if route.Dst == nil && route.Priority == 0 { - return route.Gw, route.LinkIndex, nil - } - } - - return nil, 0, ErrRouteNotFound -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) - - otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) -} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go new file mode 100644 index 000000000..11247c7dc --- /dev/null +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -0,0 +1,120 @@ +//go:build !android && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" +) + +var errRouteNotFound = fmt.Errorf("route not found") + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return err + } + if ok { + log.Warnf("skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return err + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, addr) +} + +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + if err != nil && err != errRouteNotFound { + return err + } + + addr := netip.MustParseAddr(defaultGateway.String()) + + if !prefix.Contains(addr) { + log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(addr, 32) + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) + if err != nil && err != errRouteNotFound { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop.String()) +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { + return removeFromRouteTable(prefix, addr) +} + +func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { + r, err := netroute.New() + if err != nil { + return nil, err + } + _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) + if err != nil { + log.Errorf("getting routes returned an error: %v", err) + return nil, errRouteNotFound + } + + if gateway == nil { + return preferredSrc, nil + } + + return gateway, nil +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_nonandroid_test.go similarity index 59% rename from client/internal/routemanager/systemops_test.go rename to client/internal/routemanager/systemops_nonandroid_test.go index 97386f19a..6f32d9634 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -1,32 +1,24 @@ -//go:build !android && !ios +//go:build !android package routemanager import ( "bytes" - "context" "fmt" "net" "net/netip" "os" - "runtime" "strings" "testing" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) -type dialer interface { - Dial(network, address string) (net.Conn, error) - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string @@ -61,30 +53,27 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "genericAddVPNRoute should not return err") + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { - assertWGOutInterface(t, testCase.prefix, wgInterface, false) + require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") } else { - assertWGOutInterface(t, testCase.prefix, wgInterface, true) + require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") } exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "genericRemoveVPNRoute should not return err") + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") + prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) require.NoError(t, err) if testCase.shouldBeRemoved { @@ -97,12 +86,12 @@ func TestAddRemoveRoutes(t *testing.T) { } } -func TestGetNextHop(t *testing.T) { - gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) +func TestGetExistingRIBRouteGateway(t *testing.T) { + gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if !gateway.IsValid() { + if gateway == nil { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -124,11 +113,11 @@ func TestGetNextHop(t *testing.T) { } } - localIP, _, err := getNextHop(testingPrefix.Addr()) + localIP, err := getExistingRIBRouteGateway(testingPrefix) if err != nil { t.Fatal("shouldn't return error: ", err) } - if !localIP.IsValid() { + if localIP == nil { t.Fatal("should return a gateway for local network") } if localIP.String() == gateway.String() { @@ -139,8 +128,8 @@ func TestGetNextHop(t *testing.T) { } } -func TestAddExistAndRemoveRoute(t *testing.T) { - defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) +func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { + defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) @@ -200,14 +189,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + MockAddr := wgInterface.Address().IP.String() + // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -217,7 +208,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) require.NoError(t, err, "should not return err") } @@ -226,7 +217,6 @@ func TestAddExistAndRemoveRoute(t *testing.T) { ok, err := existsInRouteTable(testCase.prefix) t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") - if !strings.Contains(buf.String(), "because it already exists") { require.False(t, ok, "route should not exist") } @@ -234,6 +224,31 @@ func TestAddExistAndRemoveRoute(t *testing.T) { } } +func TestExistsInRouteTable(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is4() { + addressPrefixes = append(addressPrefixes, p.Masked()) + } + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} + func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { @@ -271,132 +286,3 @@ func TestIsSubRange(t *testing.T) { } } } - -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is6() { - continue - } - // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { - continue - } - // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence - if runtime.GOOS == "linux" && p.Addr().IsLoopback() { - continue - } - - addressPrefixes = append(addressPrefixes, p.Masked()) - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet() - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) { - t.Helper() - - setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - _, _, err := setupRouting(nil, wgIface) - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) -} - -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { - return - } - - prefixGateway, _, err := getNextHop(prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 38026107e..47bd60eb0 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,23 +1,41 @@ -//go:build !linux && !ios +//go:build !linux +// +build !linux package routemanager import ( "net/netip" + "os/exec" "runtime" log "github.com/sirupsen/logrus" ) -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) +func addToRouteTable(prefix netip.Prefix, addr string) error { + cmd := exec.Command("route", "add", prefix.String(), addr) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) return nil } -func addVPNRoute(prefix netip.Prefix, intf string) error { - return genericAddVPNRoute(prefix, intf) +func removeFromRouteTable(prefix netip.Prefix, addr string) error { + args := []string{"delete", prefix.String()} + if runtime.GOOS == "darwin" { + args = append(args, addr) + } + cmd := exec.Command("route", args...) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) + return nil } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - return genericRemoveVPNRoute(prefix, intf) +func enableIPForwarding() error { + log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil } diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go deleted file mode 100644 index 561eaeea4..000000000 --- a/client/internal/routemanager/systemops_unix_test.go +++ /dev/null @@ -1,234 +0,0 @@ -//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly - -package routemanager - -import ( - "fmt" - "net" - "strings" - "testing" - "time" - - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/gopacket/gopacket/pcap" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool -} - -type testCase struct { - name string - destination string - expectedInterface string - dialer dialer - expectedPacket PacketExpectation -} - -var testCases = []testCase{ - { - name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", - expectedInterface: expectedExternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", - expectedInterface: expectedInternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedInterface: expectedInternalInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - - { - name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", - expectedInterface: expectedExternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), - }, - { - name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), - }, -} - -func TestRouting(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setupTestEnv(t) - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.expectedInterface, filter) - - sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.expectedPacket) - }) - } -} - -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, - } -} - -func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { - t.Helper() - - inactive, err := pcap.NewInactiveHandle(intf) - require.NoError(t, err, "Failed to create inactive pcap handle") - defer inactive.CleanUp() - - err = inactive.SetSnapLen(1600) - require.NoError(t, err, "Failed to set snap length on inactive handle") - - err = inactive.SetTimeout(time.Second * 10) - require.NoError(t, err, "Failed to set timeout on inactive handle") - - err = inactive.SetImmediateMode(true) - require.NoError(t, err, "Failed to set immediate mode on inactive handle") - - handle, err := inactive.Activate() - require.NoError(t, err, "Failed to activate pcap handle") - t.Cleanup(handle.Close) - - err = handle.SetBPFFilter(filter) - require.NoError(t, err, "Failed to set BPF filter") - - return handle -} - -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { - t.Helper() - - if dialer == nil { - dialer = &net.Dialer{} - } - - if sourcePort != 0 { - localUDPAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: sourcePort, - } - switch dialer := dialer.(type) { - case *nbnet.Dialer: - dialer.LocalAddr = localUDPAddr - case *net.Dialer: - dialer.LocalAddr = localUDPAddr - default: - t.Fatal("Unsupported dialer type") - } - } - - msg := new(dns.Msg) - msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = []dns.Question{ - {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - conn, err := dialer.Dial("udp", destination) - require.NoError(t, err, "Failed to dial UDP") - defer conn.Close() - - data, err := msg.Pack() - require.NoError(t, err, "Failed to pack DNS message") - - _, err = conn.Write(data) - if err != nil { - if strings.Contains(err.Error(), "required key not available") { - t.Logf("Ignoring WireGuard key error: %v", err) - return - } - t.Fatalf("Failed to send DNS query: %v", err) - } -} - -func createBPFFilter(destination string) string { - host, port, err := net.SplitHostPort(destination) - if err != nil { - return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) - } - return "udp" -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } -} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 50fff0cd5..309c184b9 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -1,19 +1,13 @@ //go:build windows +// +build windows package routemanager import ( - "fmt" "net" "net/netip" - "os/exec" - "strings" - log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -21,35 +15,23 @@ type Win32_IP4RouteTable struct { Mask string } -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) -} - -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) -} - func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" err := wmi.Query(query, &routes) if err != nil { - return nil, fmt.Errorf("get routes: %w", err) + return nil, err } var prefixList []netip.Prefix for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { - log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) continue } maskSlice := net.ParseIP(route.Mask).To4() if maskSlice == nil { - log.Warnf("Unable to parse route mask %s", route.Mask) continue } mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) @@ -62,69 +44,3 @@ func getRoutesFromTable() ([]netip.Prefix, error) { } return prefixList, nil } - -func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - destinationPrefix := prefix.String() - psCmd := "New-NetRoute" - - addressFamily := "IPv4" - if prefix.Addr().Is6() { - addressFamily = "IPv6" - } - - script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, - psCmd, addressFamily, destinationPrefix, intf, - ) - - if nexthop.IsValid() { - script = fmt.Sprintf( - `%s -NextHop "%s"`, script, nexthop, - ) - } - - out, err := exec.Command("powershell", "-Command", script).CombinedOutput() - log.Tracef("PowerShell add route: %s", string(out)) - - if err != nil { - return fmt.Errorf("PowerShell add route: %w", err) - } - - return nil -} - -func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"add", prefix.String(), nexthop.Unmap().String()} - - out, err := exec.Command("route", args...).CombinedOutput() - - log.Tracef("route %s output: %s", strings.Join(args, " "), out) - if err != nil { - return fmt.Errorf("route add: %w", err) - } - - return nil -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - // Powershell doesn't support adding routes without an interface but allows to add interface by name - if intf != "" { - return addRoutePowershell(prefix, nexthop, intf) - } - return addRouteCmd(prefix, nexthop, intf) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"delete", prefix.String()} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) - } - - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s output: %s", strings.Join(args, " "), out) - - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - return nil -} diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go deleted file mode 100644 index a5e03b8d2..000000000 --- a/client/internal/routemanager/systemops_windows_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package routemanager - -import ( - "context" - "encoding/json" - "fmt" - "net" - "os/exec" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -var expectedExtInt = "Ethernet1" - -type RouteInfo struct { - NextHop string `json:"nexthop"` - InterfaceAlias string `json:"interfacealias"` - RouteMetric int `json:"routemetric"` -} - -type FindNetRouteOutput struct { - IPAddress string `json:"IPAddress"` - InterfaceIndex int `json:"InterfaceIndex"` - InterfaceAlias string `json:"InterfaceAlias"` - AddressFamily int `json:"AddressFamily"` - NextHop string `json:"NextHop"` - DestinationPrefix string `json:"DestinationPrefix"` -} - -type testCase struct { - name string - destination string - expectedSourceIP string - expectedDestPrefix string - expectedNextHop string - expectedInterface string - dialer dialer -} - -var expectedVPNint = "wgtest0" - -var testCases = []testCase{ - { - name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "128.0.0.0/1", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - { - name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", - expectedDestPrefix: "192.0.2.1/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - - { - name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", - expectedDestPrefix: "10.0.0.2/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedSourceIP: "10.0.0.1", - expectedDestPrefix: "10.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", - dialer: &net.Dialer{}, - }, - - { - name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", - expectedDestPrefix: "172.16.0.2/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - { - name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "172.16.0.0/12", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - - { - name: "To more specific route without custom dialer via vpn interface", - destination: "10.10.0.2:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "10.10.0.0/24", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.2:53", - expectedSourceIP: "10.0.0.1", - expectedDestPrefix: "127.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", - dialer: &net.Dialer{}, - }, -} - -func TestRouting(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setupTestEnv(t) - - route, err := fetchOriginalGateway() - require.NoError(t, err, "Failed to fetch original gateway") - ip, err := fetchInterfaceIP(route.InterfaceAlias) - require.NoError(t, err, "Failed to fetch interface IP") - - output := testRoute(t, tc.destination, tc.dialer) - if tc.expectedInterface == expectedExtInt { - verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) - } else { - verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) - } - }) - } -} - -// fetchInterfaceIP fetches the IPv4 address of the specified interface. -func fetchInterfaceIP(interfaceAlias string) (string, error) { - script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) - out, err := exec.Command("powershell", "-Command", script).Output() - if err != nil { - return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) - } - - ip := strings.TrimSpace(string(out)) - return ip, nil -} - -func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - conn, err := dialer.DialContext(ctx, "udp", destination) - require.NoError(t, err, "Failed to dial destination") - defer func() { - err := conn.Close() - assert.NoError(t, err, "Failed to close connection") - }() - - host, _, err := net.SplitHostPort(destination) - require.NoError(t, err) - - script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) - - out, err := exec.Command("powershell", "-Command", script).Output() - require.NoError(t, err, "Failed to execute Find-NetRoute") - - var outputs []FindNetRouteOutput - err = json.Unmarshal(out, &outputs) - require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") - - require.Greater(t, len(outputs), 0, "No route found for destination") - combinedOutput := combineOutputs(outputs) - - return combinedOutput -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) - require.NoError(t, err) - subnetMaskSize, _ := ipNet.Mask.Size() - script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to assign IP address to loopback adapter") - - // Wait for the IP address to be applied - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - err = waitForIPAddress(ctx, interfaceName, ip.String()) - require.NoError(t, err, "IP address not applied within timeout") - - t.Cleanup(func() { - script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to remove IP address from loopback adapter") - }) - - return interfaceName -} - -func fetchOriginalGateway() (*RouteInfo, error) { - cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") - output, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) - } - - var routeInfo RouteInfo - err = json.Unmarshal(output, &routeInfo) - if err != nil { - return nil, fmt.Errorf("failed to parse JSON output: %w", err) - } - - return &routeInfo, nil -} - -func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { - t.Helper() - - assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") - assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") - assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") - assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") -} - -func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() - if err != nil { - return err - } - - ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") - for _, ip := range ipAddresses { - if strings.TrimSpace(ip) == expectedIPAddress { - return nil - } - } - } - } -} - -func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { - var combined FindNetRouteOutput - - for _, output := range outputs { - if output.IPAddress != "" { - combined.IPAddress = output.IPAddress - } - if output.InterfaceIndex != 0 { - combined.InterfaceIndex = output.InterfaceIndex - } - if output.InterfaceAlias != "" { - combined.InterfaceAlias = output.InterfaceAlias - } - if output.AddressFamily != 0 { - combined.AddressFamily = output.AddressFamily - } - if output.NextHop != "" { - combined.NextHop = output.NextHop - } - if output.DestinationPrefix != "" { - combined.DestinationPrefix = output.DestinationPrefix - } - } - - return &combined -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") -} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go deleted file mode 100644 index e80adb42b..000000000 --- a/client/internal/stdnet/dialer.go +++ /dev/null @@ -1,24 +0,0 @@ -package stdnet - -import ( - "net" - - "github.com/pion/transport/v3" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -// Dial connects to the address on the named network. -func (n *Net) Dial(network, address string) (net.Conn, error) { - return nbnet.NewDialer().Dial(network, address) -} - -// DialUDP connects to the address on the named UDP network. -func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { - return nbnet.DialUDP(network, laddr, raddr) -} - -// DialTCP connects to the address on the named TCP network. -func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { - return nbnet.DialTCP(network, laddr, raddr) -} diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go deleted file mode 100644 index 9ce0a5556..000000000 --- a/client/internal/stdnet/listener.go +++ /dev/null @@ -1,20 +0,0 @@ -package stdnet - -import ( - "context" - "net" - - "github.com/pion/transport/v3" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -// ListenPacket listens for incoming packets on the given network and address. -func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { - return nbnet.NewListener().ListenPacket(context.Background(), network, address) -} - -// ListenUDP acts like ListenPacket for UDP networks. -func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { - return nbnet.ListenUDP(network, locAddr) -} diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 2235c5d2b..f02b4943b 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -17,7 +17,6 @@ import ( "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" ) // WGEBPFProxy definition for proxy with EBPF support @@ -68,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - conn, err := nbnet.ListenUDP("udp", &addr) + conn, err := net.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -229,12 +228,6 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return nil, fmt.Errorf("binding to lo interface failed: %w", err) } - // Set the fwmark on the socket. - err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - // Convert the file descriptor to a PacketConn. file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index 17ebfbc49..b692ea708 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -6,8 +6,6 @@ import ( "net" log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -35,7 +33,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/go.mod b/go.mod index 29a1570c8..e4e36b966 100644 --- a/go.mod +++ b/go.mod @@ -48,7 +48,6 @@ require ( github.com/google/gopacket v1.1.19 github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 - github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 @@ -124,6 +123,7 @@ require ( github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect + github.com/gopacket/gopacket v1.1.1 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 9fe987cee..36fd13cc2 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,8 +10,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := nbnet.NetbirdFwmark + fwmark := 0 config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 24dfadf14..200bfbc96 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -13,8 +13,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgUSPConfigurer struct { @@ -39,7 +37,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error if err != nil { return err } - fwmark := getFwmark() + fwmark := 0 config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, @@ -347,10 +345,3 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } return sb.String() } - -func getFwmark() int { - if runtime.GOOS == "linux" { - return nbnet.NetbirdFwmark - } - return 0 -} diff --git a/management/client/grpc.go b/management/client/grpc.go index 0b1804906..0234f866c 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -58,7 +57,6 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE mgmCtx, addr, transportOption, - nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 74ac6c163..02b4e174d 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -21,8 +21,6 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - - nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -84,18 +82,10 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } - if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) - } - var sockErr error rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if sockErr != nil { log.Errorf("Failed to create ipv6 raw socket: %v", err) - } else { - if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) - } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7c4535e28..7531608c3 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -23,7 +23,6 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -77,7 +76,6 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo sigCtx, addr, transportOption, - nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go deleted file mode 100644 index 96b2bc32b..000000000 --- a/util/grpc/dialer.go +++ /dev/null @@ -1,22 +0,0 @@ -package grpc - -import ( - "context" - "net" - - log "github.com/sirupsen/logrus" - "google.golang.org/grpc" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -func WithCustomDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) - if err != nil { - log.Errorf("Failed to dial: %s", err) - return nil, err - } - return conn, nil - }) -} diff --git a/util/net/dialer.go b/util/net/dialer.go deleted file mode 100644 index 0786c667e..000000000 --- a/util/net/dialer.go +++ /dev/null @@ -1,21 +0,0 @@ -package net - -import ( - "net" -) - -// Dialer extends the standard net.Dialer with the ability to execute hooks before -// and after connections. This can be used to bypass the VPN for connections using this dialer. -type Dialer struct { - *net.Dialer -} - -// NewDialer returns a customized net.Dialer with overridden Control method -func NewDialer() *Dialer { - dialer := &Dialer{ - Dialer: &net.Dialer{}, - } - dialer.init() - - return dialer -} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go deleted file mode 100644 index 06fac3bbf..000000000 --- a/util/net/dialer_generic.go +++ /dev/null @@ -1,163 +0,0 @@ -//go:build !android && !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. -func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHook removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - -func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go deleted file mode 100644 index aed5c59a3..000000000 --- a/util/net/dialer_linux.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !android - -package net - -import "syscall" - -// init configures the net.Dialer Control function to set the fwmark on the socket -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) - } -} diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go deleted file mode 100644 index 3254e6d06..000000000 --- a/util/net/dialer_nonlinux.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build !linux || android - -package net - -func (d *Dialer) init() { -} diff --git a/util/net/listener.go b/util/net/listener.go deleted file mode 100644 index f4d769f58..000000000 --- a/util/net/listener.go +++ /dev/null @@ -1,21 +0,0 @@ -package net - -import ( - "net" -) - -// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before -// responding via the socket and after closing. This can be used to bypass the VPN for listeners. -type ListenerConfig struct { - *net.ListenConfig -} - -// NewListener creates a new ListenerConfig instance. -func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } - listener.init() - - return listener -} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go deleted file mode 100644 index 451279e9d..000000000 --- a/util/net/listener_generic.go +++ /dev/null @@ -1,163 +0,0 @@ -//go:build !android && !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// RemoveListenerHooks removes all dialer hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. -func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -} - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go deleted file mode 100644 index 8d332160a..000000000 --- a/util/net/listener_linux.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !android - -package net - -import ( - "syscall" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) - } -} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go deleted file mode 100644 index 0dbbb360b..000000000 --- a/util/net/listener_mobile.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build android || ios - -package net - -import ( - "net" -) - -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - return net.ListenUDP(network, laddr) -} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go deleted file mode 100644 index fb6eadaaa..000000000 --- a/util/net/listener_nonlinux.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build !linux || android - -package net - -func (l *ListenerConfig) init() { -} diff --git a/util/net/net.go b/util/net/net.go deleted file mode 100644 index 9ea7ae803..000000000 --- a/util/net/net.go +++ /dev/null @@ -1,17 +0,0 @@ -package net - -import "github.com/google/uuid" - -const ( - // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 -) - -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -// GenerateConnID generates a unique identifier for each connection. -func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go deleted file mode 100644 index 821417500..000000000 --- a/util/net/net_linux.go +++ /dev/null @@ -1,35 +0,0 @@ -//go:build !android - -package net - -import ( - "fmt" - "syscall" -) - -// SetSocketMark sets the SO_MARK option on the given socket connection -func SetSocketMark(conn syscall.Conn) error { - sysconn, err := conn.SyscallConn() - if err != nil { - return fmt.Errorf("get raw conn: %w", err) - } - - return SetRawSocketMark(sysconn) -} - -func SetRawSocketMark(conn syscall.RawConn) error { - var setErr error - - err := conn.Control(func(fd uintptr) { - setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) - }) - if err != nil { - return fmt.Errorf("control: %w", err) - } - - if setErr != nil { - return fmt.Errorf("set SO_MARK: %w", setErr) - } - - return nil -} From 3875c29f6b1d92ef2fc00029e74036001c55829b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 9 Apr 2024 01:56:52 +0900 Subject: [PATCH 20/28] Revert "Rollback new routing functionality (#1805)" (#1813) This reverts commit 9f32ccd4533d5301bcb901677af9816cb3408f92. --- .github/workflows/golang-test-darwin.yml | 3 + .github/workflows/golang-test-linux.yml | 10 +- .github/workflows/golang-test-windows.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- client/internal/engine.go | 16 + client/internal/peer/conn.go | 32 ++ client/internal/relay/relay.go | 7 +- client/internal/routemanager/client.go | 63 ++- client/internal/routemanager/manager.go | 79 ++- client/internal/routemanager/manager_test.go | 30 +- client/internal/routemanager/mock.go | 5 + client/internal/routemanager/routemanager.go | 126 +++++ .../routemanager/server_nonandroid.go | 57 +- client/internal/routemanager/systemops.go | 407 ++++++++++++++ .../routemanager/systemops_android.go | 24 +- client/internal/routemanager/systemops_bsd.go | 1 - .../internal/routemanager/systemops_darwin.go | 61 ++ .../routemanager/systemops_darwin_test.go | 100 ++++ client/internal/routemanager/systemops_ios.go | 26 +- .../internal/routemanager/systemops_linux.go | 531 +++++++++++++++--- .../routemanager/systemops_linux_test.go | 207 +++++++ .../routemanager/systemops_nonandroid.go | 120 ---- .../routemanager/systemops_nonlinux.go | 38 +- ...s_nonandroid_test.go => systemops_test.go} | 212 +++++-- .../routemanager/systemops_unix_test.go | 234 ++++++++ .../routemanager/systemops_windows.go | 88 ++- .../routemanager/systemops_windows_test.go | 289 ++++++++++ client/internal/stdnet/dialer.go | 24 + client/internal/stdnet/listener.go | 20 + client/internal/wgproxy/proxy_ebpf.go | 9 +- client/internal/wgproxy/proxy_userspace.go | 4 +- go.mod | 2 +- iface/wg_configurer_kernel.go | 4 +- iface/wg_configurer_usp.go | 11 +- management/client/grpc.go | 2 + sharedsock/sock_linux.go | 10 + signal/client/grpc.go | 2 + util/grpc/dialer.go | 22 + util/net/dialer.go | 21 + util/net/dialer_generic.go | 163 ++++++ util/net/dialer_linux.go | 12 + util/net/dialer_nonlinux.go | 6 + util/net/listener.go | 21 + util/net/listener_generic.go | 163 ++++++ util/net/listener_linux.go | 14 + util/net/listener_mobile.go | 11 + util/net/listener_nonlinux.go | 6 + util/net/net.go | 17 + util/net/net_linux.go | 35 ++ 49 files changed, 2982 insertions(+), 367 deletions(-) create mode 100644 client/internal/routemanager/routemanager.go create mode 100644 client/internal/routemanager/systemops.go create mode 100644 client/internal/routemanager/systemops_darwin.go create mode 100644 client/internal/routemanager/systemops_darwin_test.go create mode 100644 client/internal/routemanager/systemops_linux_test.go delete mode 100644 client/internal/routemanager/systemops_nonandroid.go rename client/internal/routemanager/{systemops_nonandroid_test.go => systemops_test.go} (59%) create mode 100644 client/internal/routemanager/systemops_unix_test.go create mode 100644 client/internal/routemanager/systemops_windows_test.go create mode 100644 client/internal/stdnet/dialer.go create mode 100644 client/internal/stdnet/listener.go create mode 100644 util/grpc/dialer.go create mode 100644 util/net/dialer.go create mode 100644 util/net/dialer_generic.go create mode 100644 util/net/dialer_linux.go create mode 100644 util/net/dialer_nonlinux.go create mode 100644 util/net/listener.go create mode 100644 util/net/listener_generic.go create mode 100644 util/net/listener_linux.go create mode 100644 util/net/listener_mobile.go create mode 100644 util/net/listener_nonlinux.go create mode 100644 util/net/net.go create mode 100644 util/net/net_linux.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index f8afd3d6e..d7007c860 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,6 +32,9 @@ jobs: restore-keys: | macos-go- + - name: Install libpcap + run: brew install libpcap + - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 74e6d1203..42f740e9b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -36,7 +36,11 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - name: Install modules run: go mod tidy @@ -67,7 +71,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - name: Install modules run: go mod tidy @@ -82,7 +86,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... + run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 6027d3626..2d63acbcd 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9f543c74c..13228250d 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -40,7 +40,7 @@ jobs: cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index 13ef8ce15..d6238c4b3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -94,6 +94,9 @@ type Engine struct { // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn + beforePeerHook peer.BeforeAddPeerHookFunc + afterPeerHook peer.AfterRemovePeerHookFunc + // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -261,6 +264,14 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) + beforePeerHook, afterPeerHook, err := e.routeManager.Init() + if err != nil { + log.Errorf("Failed to initialize route manager: %s", err) + } else { + e.beforePeerHook = beforePeerHook + e.afterPeerHook = afterPeerHook + } + e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -810,6 +821,11 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { } e.peerConns[peerKey] = conn + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 17ef7e87f..f3d07dcad 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -100,6 +101,9 @@ type IceCredentials struct { Pwd string } +type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error +type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error + type Conn struct { config ConnConfig mu sync.Mutex @@ -138,6 +142,10 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn + + connID nbnet.ConnectionID + beforeAddPeerHooks []BeforeAddPeerHookFunc + afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -393,6 +401,14 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } +func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { + conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) +} + +func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { + conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) +} + // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -421,6 +437,13 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem conn.remoteEndpoint = endpointUdpAddr log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + conn.connID = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + log.Errorf("Before add peer hook failed: %v", err) + } + } + err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { if conn.wgProxy != nil { @@ -511,6 +534,15 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if conn.connID != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connID); err != nil { + log.Errorf("After remove peer hook failed: %v", err) + } + } + } + conn.connID = "" + if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index ad3b94f2a..84fd72e49 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" + nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request @@ -95,15 +96,13 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) switch uri.Proto { case stun.ProtoTypeUDP: var err error - listener := &net.ListenConfig{} - conn, err = listener.ListenPacket(ctx, "udp", "") + conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "") if err != nil { probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: - dialer := &net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) + tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index f7ead5827..38cf4bf65 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -41,6 +41,7 @@ type clientNetwork struct { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { ctx, cancel := context.WithCancel(ctx) + client := &clientNetwork{ ctx: ctx, stop: cancel, @@ -72,6 +73,18 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { return routePeerStatuses } +// getBestRouteFromStatuses determines the most optimal route from the available routes +// within a clientNetwork, taking into account peer connection status, route metrics, and +// preference for non-relayed and direct connections. +// +// It follows these prioritization rules: +// * Connected peers: Only routes with connected peers are considered. +// * Metric: Routes with lower metrics (better) are prioritized. +// * Non-relayed: Routes without relays are preferred. +// * Direct connections: Routes with direct peer connections are favored. +// * Stability: In case of equal scores, the currently active route (if any) is maintained. +// +// It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" chosenScore := 0 @@ -158,7 +171,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) if err != nil { - return err + return fmt.Errorf("get peer state: %v", err) } delete(state.Routes, c.network.String()) @@ -172,7 +185,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) if err != nil { - return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", + return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } return nil @@ -180,30 +193,26 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { + return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } - err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) - if err != nil { - return fmt.Errorf("couldn't remove route %s from system, err: %v", - c.network, err) + + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route: %v", err) } } return nil } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { - - var err error - routerPeerStatuses := c.getRouterPeerStatuses() chosen := c.getBestRouteFromStatuses(routerPeerStatuses) + + // If no route is chosen, remove the route from the peer and system if chosen == "" { - err = c.removeRouteFromPeerAndSystem() - if err != nil { - return err + if err := c.removeRouteFromPeerAndSystem(); err != nil { + return fmt.Errorf("remove route from peer and system: %v", err) } c.chosenRoute = nil @@ -211,6 +220,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } + // If the chosen route is the same as the current route, do nothing if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute.IsEqual(c.routes[chosen]) { return nil @@ -218,13 +228,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + // If a previous route exists, remove it from the peer + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route from peer: %v", err) } } else { - err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) - if err != nil { + // otherwise add the route to the system + if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -245,8 +255,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } - err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) - if err != nil { + if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } @@ -287,21 +296,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system: %v", err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("received a routes update with smaller serial number, ignoring it") + log.Warnf("Received a routes update with smaller serial number, ignoring it") continue } - log.Debugf("received a new client network route update for %s", c.network) + log.Debugf("Received a new client network route update for %s", c.network) c.handleUpdate(update) @@ -309,7 +318,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index b624d8c34..36a37f02c 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,6 +2,10 @@ package routemanager import ( "context" + "fmt" + "net" + "net/netip" + "net/url" "runtime" "sync" @@ -15,8 +19,14 @@ import ( "github.com/netbirdio/netbird/version" ) +var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + +// nolint:unused +var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + // Manager is a route manager interface type Manager interface { + Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -56,6 +66,24 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } +// Init sets up the routing +func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + if err := cleanupRouting(); err != nil { + log.Warnf("Failed cleaning up routing: %v", err) + } + + mgmtAddress := m.statusRecorder.GetManagementState().URL + signalAddress := m.statusRecorder.GetSignalState().URL + ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) + + beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) + if err != nil { + return nil, nil, fmt.Errorf("setup routing: %w", err) + } + log.Info("Routing setup complete") + return beforePeerHook, afterPeerHook, nil +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -71,9 +99,15 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } + if err := cleanupRouting(); err != nil { + log.Errorf("Error cleaning up routing: %v", err) + } else { + log.Info("Routing cleanup complete") + } + m.ctx = nil } -// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps +// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -91,7 +125,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return err + return fmt.Errorf("update routes: %w", err) } } @@ -156,11 +190,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] for _, newRoute := range newRoutes { networkID := route.GetHAUniqueID(newRoute) if !ownNetworkIDs[networkID] { - // if prefix is too small, lets assume is a possible default route which is not yet supported - // we skip this route management - if newRoute.Network.Bits() < minRangeBits { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", - version.NetbirdVersion(), newRoute.Network) + if !isPrefixSupported(newRoute.Network) { continue } newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) @@ -178,3 +208,38 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } + +func isPrefixSupported(prefix netip.Prefix) bool { + switch runtime.GOOS { + case "linux", "windows", "darwin": + return true + } + + // If prefix is too small, lets assume it is a possible default prefix which is not yet supported + // we skip this prefix management + if prefix.Bits() <= minRangeBits { + log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", + version.NetbirdVersion(), prefix) + return false + } + return true +} + +// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. +func resolveURLsToIPs(urls []string) []net.IP { + var ips []net.IP + for _, rawurl := range urls { + u, err := url.Parse(rawurl) + if err != nil { + log.Errorf("Failed to parse url %s: %v", rawurl, err) + continue + } + ipAddrs, err := net.LookupIP(u.Hostname()) + if err != nil { + log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) + continue + } + ips = append(ips, ipAddrs...) + } + return ips +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2e5cf6649..03e77e09b 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedAllowed int }{ { name: "Should create 2 client networks", @@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedAllowed: 1, }, { name: "Remove 1 Client Route", @@ -415,6 +417,10 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) + + _, _, err = routeManager.Init() + + require.NoError(t, err, "should init route manager") defer routeManager.Stop() if testCase.removeSrvRouter { @@ -429,7 +435,11 @@ func TestManagerUpdateRoutes(t *testing.T) { err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") - require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") + expectedWatchers := testCase.clientNetworkWatchersExpected + if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed + } + require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index a1214cbb9..dd2c28e59 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,6 +6,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -16,6 +17,10 @@ type MockManager struct { StopFunc func() } +func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + // InitialRouteRange mock implementation of InitialRouteRange from Manager interface func (m *MockManager) InitialRouteRange() []string { return nil diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go new file mode 100644 index 000000000..8f9ff9f4b --- /dev/null +++ b/client/internal/routemanager/routemanager.go @@ -0,0 +1,126 @@ +//go:build !android && !ios + +package routemanager + +import ( + "errors" + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type ref struct { + count int + nexthop netip.Addr + intf string +} + +type RouteManager struct { + // refCountMap keeps track of the reference ref for prefixes + refCountMap map[netip.Prefix]ref + // prefixMap keeps track of the prefixes associated with a connection ID for removal + prefixMap map[nbnet.ConnectionID][]netip.Prefix + addRoute AddRouteFunc + removeRoute RemoveRouteFunc + mutex sync.Mutex +} + +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error + +func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { + // TODO: read initial routing table into refCountMap + return &RouteManager{ + refCountMap: map[netip.Prefix]ref{}, + prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, + addRoute: addRoute, + removeRoute: removeRoute, + } +} + +func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + ref := rm.refCountMap[prefix] + log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) + + // Add route to the system, only if it's a new prefix + if ref.count == 0 { + log.Debugf("Adding route for prefix %s", prefix) + nexthop, intf, err := rm.addRoute(prefix) + if errors.Is(err, ErrRouteNotFound) { + return nil + } + if errors.Is(err, ErrRouteNotAllowed) { + log.Debugf("Adding route for prefix %s: %s", prefix, err) + } + if err != nil { + return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) + } + ref.nexthop = nexthop + ref.intf = intf + } + + ref.count++ + rm.refCountMap[prefix] = ref + rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) + + return nil +} + +func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + prefixes, ok := rm.prefixMap[connID] + if !ok { + log.Debugf("No prefixes found for connection ID %s", connID) + return nil + } + + var result *multierror.Error + for _, prefix := range prefixes { + ref := rm.refCountMap[prefix] + log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) + if ref.count == 1 { + log.Debugf("Removing route for prefix %s", prefix) + // TODO: don't fail if the route is not found + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + continue + } + delete(rm.refCountMap, prefix) + } else { + ref.count-- + rm.refCountMap[prefix] = ref + } + } + delete(rm.prefixMap, connID) + + return result.ErrorOrNil() +} + +// Flush removes all references and routes from the system +func (rm *RouteManager) Flush() error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + var result *multierror.Error + for prefix := range rm.refCountMap { + log.Debugf("Removing route for prefix %s", prefix) + ref := rm.refCountMap[prefix] + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + } + } + rm.refCountMap = map[netip.Prefix]ref{} + rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 192367877..af82dc913 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,6 +4,7 @@ package routemanager import ( "context" + "fmt" "net/netip" "sync" @@ -48,7 +49,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er oldRoute := m.routes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } delete(m.routes, routeID) @@ -62,7 +63,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er err := m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue } m.routes[id] = newRoute @@ -81,15 +82,22 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not removing from server network because context is done") + log.Infof("Not removing from server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) } + + err = m.firewall.RemoveRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("remove routing rules: %w", err) + } + delete(m.routes, route.ID) state := m.statusRecorder.GetLocalPeerState() @@ -103,15 +111,22 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not adding to server network because context is done") + log.Infof("Not adding to server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) } + + err = m.firewall.InsertRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("insert routing rules: %w", err) + } + m.routes[route.ID] = route state := m.statusRecorder.GetLocalPeerState() @@ -129,23 +144,33 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) if err != nil { - log.Warnf("failed to remove clean up route: %s", r.ID) + log.Errorf("Failed to convert route to router pair: %v", err) + continue + } + + err = m.firewall.RemoveRoutingRules(routerPair) + if err != nil { + log.Errorf("Failed to remove cleanup route: %v", err) } - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } + + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } -func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { - parsed := netip.MustParsePrefix(source).Masked() +func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { + parsed, err := netip.ParsePrefix(source) + if err != nil { + return firewall.RouterPair{}, err + } return firewall.RouterPair{ ID: route.ID, Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, - } + }, nil } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 000000000..a91f53636 --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,407 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var ErrRouteNotFound = errors.New("route not found") +var ErrRouteNotAllowed = errors.New("route not allowed") + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Warnf("Failed to get route for %s: %v", ip, err) + return netip.Addr{}, nil, ErrRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, ErrRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. +// If the next hop or interface is pointing to the VPN interface, it will return the initial values. +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(), + addr.IsLinkLocalUnicast(), + addr.IsLinkLocalMulticast(), + addr.IsInterfaceLocalMulticast(), + addr.IsUnspecified(), + addr.IsMulticast(): + + return netip.Addr{}, "", ErrRouteNotAllowed + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func genericAddVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 950a26843..34d2d270f 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { + return nil +} + +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index b2da8075c..173e7c0e8 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -1,5 +1,4 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package routemanager diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go new file mode 100644 index 000000000..f34964a83 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin.go @@ -0,0 +1,61 @@ +//go:build darwin && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + "os/exec" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" +) + +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("add", prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("delete", prefix, nexthop, intf) +} + +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { + inet := "-inet" + if prefix.Addr().Is6() { + inet = "-inet6" + // Special case for IPv6 split default route, pointing to the wg interface fails + // TODO: Remove once we have IPv6 support on the interface + if prefix.Bits() == 1 { + intf = "lo0" + } + } + + args := []string{"-n", action, inet, prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } else if intf != "" { + args = append(args, "-interface", intf) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go new file mode 100644 index 000000000..5c5aaa24f --- /dev/null +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -0,0 +1,100 @@ +//go:build !ios + +package routemanager + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var expectedVPNint = "utun100" +var expectedExternalInt = "lo0" +var expectedInternalInt = "lo0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via vpn", + destination: "10.10.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), + }, + }...) +} + +func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { + t.Helper() + + err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() + require.NoError(t, err, "Failed to create loopback alias") + + t.Cleanup(func() { + err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() + assert.NoError(t, err, "Failed to remove loopback alias") + }) + + return "lo0" +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { + t.Helper() + + var originalNexthop net.IP + if dstCIDR == "0.0.0.0/0" { + var err error + originalNexthop, err = fetchOriginalGateway() + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { + t.Logf("Failed to delete route: %v, output: %s", err, output) + } + } + + t.Cleanup(func() { + if originalNexthop != nil { + err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() + assert.NoError(t, err, "Failed to restore original route") + } + }) + + err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() + require.NoError(t, err, "Failed to add route") + + t.Cleanup(func() { + err := exec.Command("route", "delete", "-net", dstCIDR).Run() + assert.NoError(t, err, "Failed to remove route") + }) +} + +func fetchOriginalGateway() (net.IP, error) { + output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() + if err != nil { + return nil, err + } + + matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) + if len(matches) == 0 { + return nil, fmt.Errorf("gateway not found") + } + + return net.ParseIP(matches[1]), nil +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index aae0f8dc8..34d2d270f 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,15 +1,33 @@ -//go:build ios - package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { + return nil +} + +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 0562826a5..ef4643727 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -3,142 +3,342 @@ package routemanager import ( + "bufio" + "context" + "errors" + "fmt" "net" "net/netip" "os" "syscall" - "unsafe" + "time" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) -// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html -// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. -type routeInfoInMemory struct { - Family byte - DstLen byte - SrcLen byte - TOS byte +const ( + // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. + NetbirdVPNTableID = 0x1BD0 + // NetbirdVPNTableName is the name of the custom routing table used by Netbird. + NetbirdVPNTableName = "netbird" - Table byte - Protocol byte - Scope byte - Type byte + // rtTablesPath is the path to the file containing the routing table names. + rtTablesPath = "/etc/iproute2/rt_tables" - Flags uint32 + // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. + ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +) + +var ErrTableIDExists = errors.New("ID exists with different name") + +var routeManager = &RouteManager{} +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" + +type ruleParams struct { + fwmark int + tableID int + family int + priority int + invert bool + suppressPrefix int + description string } -const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +func getSetupRules() []ruleParams { + return []ruleParams{ + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, + } +} -func addToRouteTable(prefix netip.Prefix, addr string) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return err +// setupRouting establishes the routing configuration for the VPN, including essential rules +// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. +// +// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over +// potential routes received and configured for the VPN. This rule is skipped for the default route and routes +// that are not in the main table. +// +// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. +// This table is where a default route or other specific routes received from the management server are configured, +// enabling VPN connectivity. +// +// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { + if isLegacy { + log.Infof("Using legacy routing setup") + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) } - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) } - ip, _, err := net.ParseCIDR(addr + addrMask) - if err != nil { - return err + defer func() { + if err != nil { + if cleanErr := cleanupRouting(); cleanErr != nil { + log.Errorf("Error cleaning up routing: %v", cleanErr) + } + } + }() + + rules := getSetupRules() + for _, rule := range rules { + if err := addRule(rule); err != nil { + if errors.Is(err, syscall.EOPNOTSUPP) { + log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") + isLegacy = true + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + return nil, nil, fmt.Errorf("%s: %w", rule.description, err) + } } - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, + return nil, nil, nil +} + +// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. +// It systematically removes the three rules and any associated routing table entries to ensure a clean state. +// The function uses error aggregation to report any errors encountered during the cleanup process. +func cleanupRouting() error { + if isLegacy { + return cleanupRoutingWithRouteManager(routeManager) } - err = netlink.RouteAdd(route) - if err != nil { - return err + var result *multierror.Error + + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err)) + } + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err)) } + rules := getSetupRules() + for _, rule := range rules { + if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { + result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) + } + } + + return result.ErrorOrNil() +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func addVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericAddVPNRoute(prefix, intf) + } + + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 + + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { + return fmt.Errorf("add blackhole: %w", err) + } + } + if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { + return fmt.Errorf("add route: %w", err) + } return nil } -func removeFromRouteTable(prefix netip.Prefix, addr string) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return err +func removeVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericRemoveVPNRoute(prefix, intf) } - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { + return fmt.Errorf("remove unreachable route: %w", err) + } } - - ip, _, err := net.ParseCIDR(addr + addrMask) - if err != nil { - return err + if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { + return fmt.Errorf("remove route: %w", err) } - - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, - } - - err = netlink.RouteDel(route) - if err != nil { - return err - } - return nil } func getRoutesFromTable() ([]netip.Prefix, error) { - tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) + v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) if err != nil { - return nil, err + return nil, fmt.Errorf("get v4 routes: %w", err) } - msgs, err := syscall.ParseNetlinkMessage(tab) + v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) if err != nil { - return nil, err + return nil, fmt.Errorf("get v6 routes: %w", err) + } + return append(v4Routes, v6Routes...), nil +} + +// getRoutes fetches routes from a specific routing table identified by tableID. +func getRoutes(tableID, family int) ([]netip.Prefix, error) { var prefixList []netip.Prefix -loop: - for _, m := range msgs { - switch m.Header.Type { - case syscall.NLMSG_DONE: - break loop - case syscall.RTM_NEWROUTE: - rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) - msg := m - attrs, err := syscall.ParseNetlinkRouteAttr(&msg) - if err != nil { - return nil, err - } - if rt.Family != syscall.AF_INET { - continue loop + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) + } + + for _, route := range routes { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) } - for _, attr := range attrs { - if attr.Attr.Type == syscall.RTA_DST { - addr, ok := netip.AddrFromSlice(attr.Value) - if !ok { - continue - } - mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) - cidr, _ := mask.Size() - routePrefix := netip.PrefixFrom(addr, cidr) - if routePrefix.IsValid() && routePrefix.Addr().Is4() { - prefixList = append(prefixList, routePrefix) - } - } + ones, _ := route.Dst.Mask.Size() + + prefix := netip.PrefixFrom(addr, ones) + if prefix.IsValid() { + prefixList = append(prefixList, prefix) } } } + return prefixList, nil } +// addRoute adds a route to a specific routing table identified by tableID. +func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: getAddressFamily(prefix), + } + + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + route.Dst = ipNet + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink add route: %w", err) + } + + return nil +} + +// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. +// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. +// tableID specifies the routing table to which the unreachable route will be added. +func addUnreachableRoute(prefix netip.Prefix, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + + route := &netlink.Route{ + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, + } + + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink add unreachable route: %w", err) + } + + return nil +} + +func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + + route := &netlink.Route{ + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, + } + + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink remove unreachable route: %w", err) + } + + return nil + +} + +// removeRoute removes a route from a specific routing table identified by tableID. +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, + } + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink remove route: %w", err) + } + + return nil +} + +func flushRoutes(tableID, family int) error { + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return fmt.Errorf("list routes from table %d: %w", tableID, err) + } + + var result *multierror.Error + for i := range routes { + route := routes[i] + // unreachable default routes don't come back with Dst set + if route.Gw == nil && route.Src == nil && route.Dst == nil { + if family == netlink.FAMILY_V4 { + routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} + } else { + routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} + } + } + if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) + } + } + + return result.ErrorOrNil() +} + func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { - return err + return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) } // check if it is already enabled @@ -147,5 +347,162 @@ func enableIPForwarding() error { return nil } - return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec + //nolint:gosec + if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { + return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) + } + return nil +} + +// entryExists checks if the specified ID or name already exists in the rt_tables file +// and verifies if existing names start with "netbird_". +func entryExists(file *os.File, id int) (bool, error) { + if _, err := file.Seek(0, 0); err != nil { + return false, fmt.Errorf("seek rt_tables: %w", err) + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + var existingID int + var existingName string + if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil { + if existingID == id { + if existingName != NetbirdVPNTableName { + return true, ErrTableIDExists + } + return true, nil + } + } + } + if err := scanner.Err(); err != nil { + return false, fmt.Errorf("scan rt_tables: %w", err) + } + return false, nil +} + +// addRoutingTableName adds human-readable names for custom routing tables. +func addRoutingTableName() error { + file, err := os.Open(rtTablesPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("open rt_tables: %w", err) + } + defer func() { + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables: %v", err) + } + }() + + exists, err := entryExists(file, NetbirdVPNTableID) + if err != nil { + return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err) + } + if exists { + return nil + } + + // Reopen the file in append mode to add new entries + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables before appending: %v", err) + } + file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) + if err != nil { + return fmt.Errorf("open rt_tables for appending: %w", err) + } + + if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil { + return fmt.Errorf("append entry to rt_tables: %w", err) + } + + return nil +} + +// addRule adds a routing rule to a specific routing table identified by tableID. +func addRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Priority = params.priority + rule.Invert = params.invert + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("add routing rule: %w", err) + } + + return nil +} + +// removeRule removes a routing rule from a specific routing table identified by tableID. +func removeRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Invert = params.invert + rule.Priority = params.priority + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleDel(rule); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + + return nil +} + +func removeAllRules(params ruleParams) error { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + for { + if ctx.Err() != nil { + done <- ctx.Err() + return + } + if err := removeRule(params); err != nil { + if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { + done <- nil + return + } + done <- err + return + } + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return err + } +} + +// addNextHop adds the gateway and device to the route. +func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { + if addr.IsValid() { + route.Gw = addr.AsSlice() + } + + if intf != "" { + link, err := netlink.LinkByName(intf) + if err != nil { + return fmt.Errorf("set interface %s: %w", intf, err) + } + route.LinkIndex = link.Attrs().Index + } + + return nil +} + +func getAddressFamily(prefix netip.Prefix) int { + if prefix.Addr().Is4() { + return netlink.FAMILY_V4 + } + return netlink.FAMILY_V6 } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go new file mode 100644 index 000000000..0043c3f4e --- /dev/null +++ b/client/internal/routemanager/systemops_linux_test.go @@ -0,0 +1,207 @@ +//go:build !android + +package routemanager + +import ( + "errors" + "fmt" + "net" + "os" + "strings" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" +) + +var expectedVPNint = "wgtest0" +var expectedLoopbackInt = "lo" +var expectedExternalInt = "dummyext0" +var expectedInternalInt = "dummyint0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via physical interface", + destination: "10.10.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), + }, + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.1:53", + expectedInterface: expectedLoopbackInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + }...) +} + +func TestEntryExists(t *testing.T) { + tempDir := t.TempDir() + tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) + + content := []string{ + "1000 reserved", + fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName), + "9999 other_table", + } + require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644)) + + file, err := os.Open(tempFilePath) + require.NoError(t, err) + defer func() { + assert.NoError(t, file.Close()) + }() + + tests := []struct { + name string + id int + shouldExist bool + err error + }{ + { + name: "ExistsWithNetbirdPrefix", + id: 7120, + shouldExist: true, + err: nil, + }, + { + name: "ExistsWithDifferentName", + id: 1000, + shouldExist: true, + err: ErrTableIDExists, + }, + { + name: "DoesNotExist", + id: 1234, + shouldExist: false, + err: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + exists, err := entryExists(file, tc.id) + if tc.err != nil { + assert.ErrorIs(t, err, tc.err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.shouldExist, exists) + }) + } +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + t.Helper() + + dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} + err := netlink.LinkDel(dummy) + if err != nil && !errors.Is(err, syscall.EINVAL) { + t.Logf("Failed to delete dummy interface: %v", err) + } + + err = netlink.LinkAdd(dummy) + require.NoError(t, err) + + err = netlink.LinkSetUp(dummy) + require.NoError(t, err) + + if ipAddressCIDR != "" { + addr, err := netlink.ParseAddr(ipAddressCIDR) + require.NoError(t, err) + err = netlink.AddrAdd(dummy, addr) + require.NoError(t, err) + } + + t.Cleanup(func() { + err := netlink.LinkDel(dummy) + assert.NoError(t, err) + }) + + return dummy.Name +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { + t.Helper() + + _, dstIPNet, err := net.ParseCIDR(dstCIDR) + require.NoError(t, err) + + // Handle existing routes with metric 0 + var originalNexthop net.IP + var originalLinkIndex int + if dstIPNet.String() == "0.0.0.0/0" { + var err error + originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if originalNexthop != nil { + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + switch { + case err != nil && !errors.Is(err, syscall.ESRCH): + t.Logf("Failed to delete route: %v", err) + case err == nil: + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + default: + t.Logf("Failed to delete route: %v", err) + } + } + } + + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + + route := &netlink.Route{ + Dst: dstIPNet, + Gw: gw, + LinkIndex: linkIndex, + } + err = netlink.RouteDel(route) + if err != nil && !errors.Is(err, syscall.ESRCH) { + t.Logf("Failed to delete route: %v", err) + } + + err = netlink.RouteAdd(route) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + require.NoError(t, err) +} + +func fetchOriginalGateway(family int) (net.IP, int, error) { + routes, err := netlink.RouteList(nil, family) + if err != nil { + return nil, 0, err + } + + for _, route := range routes { + if route.Dst == nil && route.Priority == 0 { + return route.Gw, route.LinkIndex, nil + } + } + + return nil, 0, ErrRouteNotFound +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go deleted file mode 100644 index 11247c7dc..000000000 --- a/client/internal/routemanager/systemops_nonandroid.go +++ /dev/null @@ -1,120 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "fmt" - "net" - "net/netip" - - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" -) - -var errRouteNotFound = fmt.Errorf("route not found") - -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return err - } - if ok { - log.Warnf("skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return err - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, addr) -} - -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil && err != errRouteNotFound { - return err - } - - addr := netip.MustParseAddr(defaultGateway.String()) - - if !prefix.Contains(addr) { - log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(addr, 32) - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) - if err != nil && err != errRouteNotFound { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop.String()) -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, err - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, err - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { - return removeFromRouteTable(prefix, addr) -} - -func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { - r, err := netroute.New() - if err != nil { - return nil, err - } - _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) - if err != nil { - log.Errorf("getting routes returned an error: %v", err) - return nil, errRouteNotFound - } - - if gateway == nil { - return preferredSrc, nil - } - - return gateway, nil -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 47bd60eb0..38026107e 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,41 +1,23 @@ -//go:build !linux -// +build !linux +//go:build !linux && !ios package routemanager import ( "net/netip" - "os/exec" "runtime" log "github.com/sirupsen/logrus" ) -func addToRouteTable(prefix netip.Prefix, addr string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) - return nil -} - -func removeFromRouteTable(prefix netip.Prefix, addr string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) - } - cmd := exec.Command("route", args...) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) - return nil -} - func enableIPForwarding() error { - log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } + +func addVPNRoute(prefix netip.Prefix, intf string) error { + return genericAddVPNRoute(prefix, intf) +} + +func removeVPNRoute(prefix netip.Prefix, intf string) error { + return genericRemoveVPNRoute(prefix, intf) +} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_test.go similarity index 59% rename from client/internal/routemanager/systemops_nonandroid_test.go rename to client/internal/routemanager/systemops_test.go index 6f32d9634..97386f19a 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -1,24 +1,32 @@ -//go:build !android +//go:build !android && !ios package routemanager import ( "bytes" + "context" "fmt" "net" "net/netip" "os" + "runtime" "strings" "testing" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string @@ -53,27 +61,30 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericAddVPNRoute should not return err") - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { - require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + assertWGOutInterface(t, testCase.prefix, wgInterface, false) } else { - require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") + assertWGOutInterface(t, testCase.prefix, wgInterface, true) } exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) - require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericRemoveVPNRoute should not return err") - prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") + prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") - internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) require.NoError(t, err) if testCase.shouldBeRemoved { @@ -86,12 +97,12 @@ func TestAddRemoveRoutes(t *testing.T) { } } -func TestGetExistingRIBRouteGateway(t *testing.T) { - gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) +func TestGetNextHop(t *testing.T) { + gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if gateway == nil { + if !gateway.IsValid() { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -113,11 +124,11 @@ func TestGetExistingRIBRouteGateway(t *testing.T) { } } - localIP, err := getExistingRIBRouteGateway(testingPrefix) + localIP, _, err := getNextHop(testingPrefix.Addr()) if err != nil { t.Fatal("shouldn't return error: ", err) } - if localIP == nil { + if !localIP.IsValid() { t.Fatal("should return a gateway for local network") } if localIP.String() == gateway.String() { @@ -128,8 +139,8 @@ func TestGetExistingRIBRouteGateway(t *testing.T) { } } -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) +func TestAddExistAndRemoveRoute(t *testing.T) { + defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) @@ -189,16 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - MockAddr := wgInterface.Address().IP.String() - // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) + err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -208,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -217,6 +226,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { ok, err := existsInRouteTable(testCase.prefix) t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") + if !strings.Contains(buf.String(), "because it already exists") { require.False(t, ok, "route should not exist") } @@ -224,31 +234,6 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { } } -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { - addressPrefixes = append(addressPrefixes, p.Masked()) - } - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} - func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { @@ -286,3 +271,132 @@ func TestIsSubRange(t *testing.T) { } } } + +func TestExistsInRouteTable(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is6() { + continue + } + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { + continue + } + // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence + if runtime.GOOS == "linux" && p.Addr().IsLoopback() { + continue + } + + addressPrefixes = append(addressPrefixes, p.Masked()) + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet() + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} + +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { + return + } + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go new file mode 100644 index 000000000..561eaeea4 --- /dev/null +++ b/client/internal/routemanager/systemops_unix_test.go @@ -0,0 +1,234 @@ +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly + +package routemanager + +import ( + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +type testCase struct { + name string + destination string + expectedInterface string + dialer dialer + expectedPacket PacketExpectation +} + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.expectedInterface, filter) + + sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.expectedPacket) + }) + } +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + switch dialer := dialer.(type) { + case *nbnet.Dialer: + dialer.LocalAddr = localUDPAddr + case *net.Dialer: + dialer.LocalAddr = localUDPAddr + default: + t.Fatal("Unsupported dialer type") + } + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 309c184b9..50fff0cd5 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -1,13 +1,19 @@ //go:build windows -// +build windows package routemanager import ( + "fmt" "net" "net/netip" + "os/exec" + "strings" + log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -15,23 +21,35 @@ type Win32_IP4RouteTable struct { Mask string } +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" err := wmi.Query(query, &routes) if err != nil { - return nil, err + return nil, fmt.Errorf("get routes: %w", err) } var prefixList []netip.Prefix for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { + log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) continue } maskSlice := net.ParseIP(route.Mask).To4() if maskSlice == nil { + log.Warnf("Unable to parse route mask %s", route.Mask) continue } mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) @@ -44,3 +62,69 @@ func getRoutesFromTable() ([]netip.Prefix, error) { } return prefixList, nil } + +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + destinationPrefix := prefix.String() + psCmd := "New-NetRoute" + + addressFamily := "IPv4" + if prefix.Addr().Is6() { + addressFamily = "IPv6" + } + + script := fmt.Sprintf( + `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, intf, + ) + + if nexthop.IsValid() { + script = fmt.Sprintf( + `%s -NextHop "%s"`, script, nexthop, + ) + } + + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + log.Tracef("PowerShell add route: %s", string(out)) + + if err != nil { + return fmt.Errorf("PowerShell add route: %w", err) + } + + return nil +} + +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"add", prefix.String(), nexthop.Unmap().String()} + + out, err := exec.Command("route", args...).CombinedOutput() + + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + if err != nil { + return fmt.Errorf("route add: %w", err) + } + + return nil +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + // Powershell doesn't support adding routes without an interface but allows to add interface by name + if intf != "" { + return addRoutePowershell(prefix, nexthop, intf) + } + return addRouteCmd(prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"delete", prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go new file mode 100644 index 000000000..a5e03b8d2 --- /dev/null +++ b/client/internal/routemanager/systemops_windows_test.go @@ -0,0 +1,289 @@ +package routemanager + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +var expectedExtInt = "Ethernet1" + +type RouteInfo struct { + NextHop string `json:"nexthop"` + InterfaceAlias string `json:"interfacealias"` + RouteMetric int `json:"routemetric"` +} + +type FindNetRouteOutput struct { + IPAddress string `json:"IPAddress"` + InterfaceIndex int `json:"InterfaceIndex"` + InterfaceAlias string `json:"InterfaceAlias"` + AddressFamily int `json:"AddressFamily"` + NextHop string `json:"NextHop"` + DestinationPrefix string `json:"DestinationPrefix"` +} + +type testCase struct { + name string + destination string + expectedSourceIP string + expectedDestPrefix string + expectedNextHop string + expectedInterface string + dialer dialer +} + +var expectedVPNint = "wgtest0" + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "128.0.0.0/1", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedDestPrefix: "192.0.2.1/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedDestPrefix: "10.0.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "10.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedDestPrefix: "172.16.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "172.16.0.0/12", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route without custom dialer via vpn interface", + destination: "10.10.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "10.10.0.0/24", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "127.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + route, err := fetchOriginalGateway() + require.NoError(t, err, "Failed to fetch original gateway") + ip, err := fetchInterfaceIP(route.InterfaceAlias) + require.NoError(t, err, "Failed to fetch interface IP") + + output := testRoute(t, tc.destination, tc.dialer) + if tc.expectedInterface == expectedExtInt { + verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) + } else { + verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + } + }) + } +} + +// fetchInterfaceIP fetches the IPv4 address of the specified interface. +func fetchInterfaceIP(interfaceAlias string) (string, error) { + script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) + out, err := exec.Command("powershell", "-Command", script).Output() + if err != nil { + return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) + } + + ip := strings.TrimSpace(string(out)) + return ip, nil +} + +func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "udp", destination) + require.NoError(t, err, "Failed to dial destination") + defer func() { + err := conn.Close() + assert.NoError(t, err, "Failed to close connection") + }() + + host, _, err := net.SplitHostPort(destination) + require.NoError(t, err) + + script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) + + out, err := exec.Command("powershell", "-Command", script).Output() + require.NoError(t, err, "Failed to execute Find-NetRoute") + + var outputs []FindNetRouteOutput + err = json.Unmarshal(out, &outputs) + require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") + + require.Greater(t, len(outputs), 0, "No route found for destination") + combinedOutput := combineOutputs(outputs) + + return combinedOutput +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + t.Helper() + + ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) + require.NoError(t, err) + subnetMaskSize, _ := ipNet.Mask.Size() + script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to assign IP address to loopback adapter") + + // Wait for the IP address to be applied + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = waitForIPAddress(ctx, interfaceName, ip.String()) + require.NoError(t, err, "IP address not applied within timeout") + + t.Cleanup(func() { + script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to remove IP address from loopback adapter") + }) + + return interfaceName +} + +func fetchOriginalGateway() (*RouteInfo, error) { + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) + } + + var routeInfo RouteInfo + err = json.Unmarshal(output, &routeInfo) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON output: %w", err) + } + + return &routeInfo, nil +} + +func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { + t.Helper() + + assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") + assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") + assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") + assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") +} + +func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() + if err != nil { + return err + } + + ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") + for _, ip := range ipAddresses { + if strings.TrimSpace(ip) == expectedIPAddress { + return nil + } + } + } + } +} + +func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { + var combined FindNetRouteOutput + + for _, output := range outputs { + if output.IPAddress != "" { + combined.IPAddress = output.IPAddress + } + if output.InterfaceIndex != 0 { + combined.InterfaceIndex = output.InterfaceIndex + } + if output.InterfaceAlias != "" { + combined.InterfaceAlias = output.InterfaceAlias + } + if output.AddressFamily != 0 { + combined.AddressFamily = output.AddressFamily + } + if output.NextHop != "" { + combined.NextHop = output.NextHop + } + if output.DestinationPrefix != "" { + combined.DestinationPrefix = output.DestinationPrefix + } + } + + return &combined +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") +} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go new file mode 100644 index 000000000..e80adb42b --- /dev/null +++ b/client/internal/stdnet/dialer.go @@ -0,0 +1,24 @@ +package stdnet + +import ( + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// Dial connects to the address on the named network. +func (n *Net) Dial(network, address string) (net.Conn, error) { + return nbnet.NewDialer().Dial(network, address) +} + +// DialUDP connects to the address on the named UDP network. +func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.DialUDP(network, laddr, raddr) +} + +// DialTCP connects to the address on the named TCP network. +func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + return nbnet.DialTCP(network, laddr, raddr) +} diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go new file mode 100644 index 000000000..9ce0a5556 --- /dev/null +++ b/client/internal/stdnet/listener.go @@ -0,0 +1,20 @@ +package stdnet + +import ( + "context" + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// ListenPacket listens for incoming packets on the given network and address. +func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { + return nbnet.NewListener().ListenPacket(context.Background(), network, address) +} + +// ListenUDP acts like ListenPacket for UDP networks. +func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.ListenUDP(network, locAddr) +} diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index f02b4943b..2235c5d2b 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -17,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) // WGEBPFProxy definition for proxy with EBPF support @@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - conn, err := net.ListenUDP("udp", &addr) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -228,6 +229,12 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return nil, fmt.Errorf("binding to lo interface failed: %w", err) } + // Set the fwmark on the socket. + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + // Convert the file descriptor to a PacketConn. file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index b692ea708..17ebfbc49 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -6,6 +6,8 @@ import ( "net" log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -33,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/go.mod b/go.mod index e4e36b966..29a1570c8 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/google/gopacket v1.1.19 github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 @@ -123,7 +124,6 @@ require ( github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect - github.com/gopacket/gopacket v1.1.1 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 36fd13cc2..9fe987cee 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,6 +10,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -29,7 +31,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := 0 + fwmark := nbnet.NetbirdFwmark config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 200bfbc96..24dfadf14 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -13,6 +13,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgUSPConfigurer struct { @@ -37,7 +39,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error if err != nil { return err } - fwmark := 0 + fwmark := getFwmark() config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, @@ -345,3 +347,10 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } return sb.String() } + +func getFwmark() int { + if runtime.GOOS == "linux" { + return nbnet.NetbirdFwmark + } + return 0 +} diff --git a/management/client/grpc.go b/management/client/grpc.go index 0234f866c..0b1804906 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -57,6 +58,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE mgmCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 02b4e174d..74ac6c163 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -21,6 +21,8 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" + + nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -82,10 +84,18 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } + if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + } + var sockErr error rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if sockErr != nil { log.Errorf("Failed to create ipv6 raw socket: %v", err) + } else { + if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7531608c3..7c4535e28 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -76,6 +77,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo sigCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go new file mode 100644 index 000000000..96b2bc32b --- /dev/null +++ b/util/grpc/dialer.go @@ -0,0 +1,22 @@ +package grpc + +import ( + "context" + "net" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func WithCustomDialer() grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, err + } + return conn, nil + }) +} diff --git a/util/net/dialer.go b/util/net/dialer.go new file mode 100644 index 000000000..0786c667e --- /dev/null +++ b/util/net/dialer.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// Dialer extends the standard net.Dialer with the ability to execute hooks before +// and after connections. This can be used to bypass the VPN for connections using this dialer. +type Dialer struct { + *net.Dialer +} + +// NewDialer returns a customized net.Dialer with overridden Control method +func NewDialer() *Dialer { + dialer := &Dialer{ + Dialer: &net.Dialer{}, + } + dialer.init() + + return dialer +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go new file mode 100644 index 000000000..06fac3bbf --- /dev/null +++ b/util/net/dialer_generic.go @@ -0,0 +1,163 @@ +//go:build !android && !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" +) + +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc +) + +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) +} + +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) +} + +// RemoveDialerHook removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil +} + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialerDialHooks != nil { + if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} + +func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go new file mode 100644 index 000000000..aed5c59a3 --- /dev/null +++ b/util/net/dialer_linux.go @@ -0,0 +1,12 @@ +//go:build !android + +package net + +import "syscall" + +// init configures the net.Dialer Control function to set the fwmark on the socket +func (d *Dialer) init() { + d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) + } +} diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go new file mode 100644 index 000000000..3254e6d06 --- /dev/null +++ b/util/net/dialer_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (d *Dialer) init() { +} diff --git a/util/net/listener.go b/util/net/listener.go new file mode 100644 index 000000000..f4d769f58 --- /dev/null +++ b/util/net/listener.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before +// responding via the socket and after closing. This can be used to bypass the VPN for listeners. +type ListenerConfig struct { + *net.ListenConfig +} + +// NewListener creates a new ListenerConfig instance. +func NewListener() *ListenerConfig { + listener := &ListenerConfig{ + ListenConfig: &net.ListenConfig{}, + } + listener.init() + + return listener +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go new file mode 100644 index 000000000..451279e9d --- /dev/null +++ b/util/net/listener_generic.go @@ -0,0 +1,163 @@ +//go:build !android && !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) +} + +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +// RemoveListenerHooks removes all dialer hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil +} + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.UDPConn) +} + +func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + return + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + return + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(id, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } +} + +func closeConn(id ConnectionID, conn net.PacketConn) error { + err := conn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(id, conn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + return err +} + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go new file mode 100644 index 000000000..8d332160a --- /dev/null +++ b/util/net/listener_linux.go @@ -0,0 +1,14 @@ +//go:build !android + +package net + +import ( + "syscall" +) + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) + } +} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go new file mode 100644 index 000000000..0dbbb360b --- /dev/null +++ b/util/net/listener_mobile.go @@ -0,0 +1,11 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, laddr) +} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go new file mode 100644 index 000000000..fb6eadaaa --- /dev/null +++ b/util/net/listener_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (l *ListenerConfig) init() { +} diff --git a/util/net/net.go b/util/net/net.go new file mode 100644 index 000000000..9ea7ae803 --- /dev/null +++ b/util/net/net.go @@ -0,0 +1,17 @@ +package net + +import "github.com/google/uuid" + +const ( + // NetbirdFwmark is the fwmark value used by Netbird via wireguard + NetbirdFwmark = 0x1BD00 +) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} diff --git a/util/net/net_linux.go b/util/net/net_linux.go new file mode 100644 index 000000000..821417500 --- /dev/null +++ b/util/net/net_linux.go @@ -0,0 +1,35 @@ +//go:build !android + +package net + +import ( + "fmt" + "syscall" +) + +// SetSocketMark sets the SO_MARK option on the given socket connection +func SetSocketMark(conn syscall.Conn) error { + sysconn, err := conn.SyscallConn() + if err != nil { + return fmt.Errorf("get raw conn: %w", err) + } + + return SetRawSocketMark(sysconn) +} + +func SetRawSocketMark(conn syscall.RawConn) error { + var setErr error + + err := conn.Control(func(fd uintptr) { + setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + }) + if err != nil { + return fmt.Errorf("control: %w", err) + } + + if setErr != nil { + return fmt.Errorf("set SO_MARK: %w", setErr) + } + + return nil +} From c28657710a7184a43f14b0a0d341de819ebc3af7 Mon Sep 17 00:00:00 2001 From: verytrap <166317454+verytrap@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:18:38 +0800 Subject: [PATCH 21/28] Fix function names in comments (#1816) Signed-off-by: verytrap --- management/server/account.go | 2 +- management/server/idp/okta.go | 2 +- management/server/mock_server/account_mock.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c145c1bd7..20bd15ad6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -278,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou return routes } -// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership +// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { var filteredRoutes []*route.Route for _, r := range routes { diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index d20ee7e48..c8d33a207 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error { return nil } -// parseOktaUserToUserData parse okta user to UserData. +// parseOktaUser parse okta user to UserData. func parseOktaUser(user *okta.User) (*UserData, error) { var oktaUser struct { Email string `json:"email"` diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8e7c47a28..8687937dc 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -706,7 +706,7 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { return nil } -// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface +// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { if am.UpdateIntegratedValidatorGroupsFunc != nil { return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) From ac0fe6025b0ea565262b1fdd961f1eba439f59e5 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 9 Apr 2024 13:25:14 +0200 Subject: [PATCH 22/28] Fix routing issues with MacOS (#1815) * Handle zones properly * Use host routes for single IPs * Add GOOS and GOARCH to startup log * Log powershell command --- client/internal/connect.go | 3 +- client/internal/routemanager/systemops.go | 37 +++++++++++++++---- client/internal/routemanager/systemops_bsd.go | 28 +++++++++----- .../internal/routemanager/systemops_darwin.go | 6 ++- .../internal/routemanager/systemops_linux.go | 3 ++ .../routemanager/systemops_windows.go | 33 +++++++++++++---- util/net/dialer_generic.go | 4 +- 7 files changed, 84 insertions(+), 30 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 682a1efed..b50b3a629 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "runtime" "strings" "time" @@ -93,7 +94,7 @@ func runClient( relayProbe *Probe, wgProbe *Probe, ) error { - log.Infof("starting NetBird client version %s", version.NetbirdVersion()) + log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) // Check if client was not shut down in a clean way and restore DNS config if required. // Otherwise, we might not be able to connect to the management server to retrieve new config. diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index a91f53636..1ee54b746 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -8,6 +8,8 @@ import ( "fmt" "net" "net/netip" + "runtime" + "strconv" "github.com/hashicorp/go-multierror" "github.com/libp2p/go-netroute" @@ -85,23 +87,42 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { if preferredSrc == nil { - return netip.Addr{}, nil, ErrRouteNotFound + return netip.Addr{}, nil, ErrRouteNotFound } log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + addr, err := ipToAddr(preferredSrc, intf) + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err) } return addr.Unmap(), intf, nil } - addr, ok := netip.AddrFromSlice(gateway) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + addr, err := ipToAddr(gateway, intf) + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err) } - return addr.Unmap(), intf, nil + return addr, intf, nil +} + +// converts a net.IP to a netip.Addr including the zone based on the passed interface +func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip) + } + + if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) { + log.Tracef("Adding zone %s to address %s", intf.Name, addr) + if runtime.GOOS == "windows" { + addr = addr.WithZone(strconv.Itoa(intf.Index)) + } else { + addr = addr.WithZone(intf.Name) + } + } + + return addr.Unmap(), nil } func existsInRouteTable(prefix netip.Prefix) (bool, error) { diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 173e7c0e8..b6a2006e7 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -8,6 +8,7 @@ import ( "net/netip" "syscall" + log "github.com/sirupsen/logrus" "golang.org/x/net/route" ) @@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) { continue } + if len(m.Addrs) < 3 { + log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs) + continue + } + addr, ok := toNetIPAddr(m.Addrs[0]) if !ok { continue } - mask, ok := toNetIPMASK(m.Addrs[2]) - if !ok { - continue + cidr := 32 + if mask := m.Addrs[2]; mask != nil { + cidr, ok = toCIDR(mask) + if !ok { + log.Debugf("Unexpected RIB message Addrs[2]: %v", mask) + continue + } } - cidr, _ := mask.Size() routePrefix := netip.PrefixFrom(addr, cidr) if routePrefix.IsValid() { @@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) { func toNetIPAddr(a route.Addr) (netip.Addr, bool) { switch t := a.(type) { case *route.Inet4Addr: - ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - addr := netip.MustParseAddr(ip.String()) - return addr, true + return netip.AddrFrom4(t.IP), true default: return netip.Addr{}, false } } -func toNetIPMASK(a route.Addr) (net.IPMask, bool) { +func toCIDR(a route.Addr) (int, bool) { switch t := a.(type) { case *route.Inet4Addr: mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - return mask, true + cidr, _ := mask.Size() + return cidr, true default: - return nil, false + return 0, false } } diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go index f34964a83..33b8287b6 100644 --- a/client/internal/routemanager/systemops_darwin.go +++ b/client/internal/routemanager/systemops_darwin.go @@ -35,6 +35,10 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { inet := "-inet" + network := prefix.String() + if prefix.IsSingleIP() { + network = prefix.Addr().String() + } if prefix.Addr().Is6() { inet = "-inet6" // Special case for IPv6 split default route, pointing to the wg interface fails @@ -44,7 +48,7 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin } } - args := []string{"-n", action, inet, prefix.String()} + args := []string{"-n", action, inet, network} if nexthop.IsValid() { args = append(args, nexthop.Unmap().String()) } else if intf != "" { diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index ef4643727..dd00626e1 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -487,6 +487,9 @@ func removeAllRules(params ruleParams) error { func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { if addr.IsValid() { route.Gw = addr.AsSlice() + if intf == "" { + intf = addr.Zone() + } } if intf != "" { diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 50fff0cd5..334ace453 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -63,7 +63,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } -func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error { destinationPrefix := prefix.String() psCmd := "New-NetRoute" @@ -73,10 +73,20 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er } script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, - psCmd, addressFamily, destinationPrefix, intf, + `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, ) + if intfIdx != "" { + script = fmt.Sprintf( + `%s -InterfaceIndex %s`, script, intfIdx, + ) + } else { + script = fmt.Sprintf( + `%s -InterfaceAlias "%s"`, script, intf, + ) + } + if nexthop.IsValid() { script = fmt.Sprintf( `%s -NextHop "%s"`, script, nexthop, @@ -84,7 +94,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er } out, err := exec.Command("powershell", "-Command", script).CombinedOutput() - log.Tracef("PowerShell add route: %s", string(out)) + log.Tracef("PowerShell %s: %s", script, string(out)) if err != nil { return fmt.Errorf("PowerShell add route: %w", err) @@ -98,7 +108,7 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s output: %s", strings.Join(args, " "), out) + log.Tracef("route %s: %s", strings.Join(args, " "), out) if err != nil { return fmt.Errorf("route add: %w", err) } @@ -107,9 +117,15 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { } func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + var intfIdx string + if nexthop.Zone() != "" { + intfIdx = nexthop.Zone() + nexthop.WithZone("") + } + // Powershell doesn't support adding routes without an interface but allows to add interface by name - if intf != "" { - return addRoutePowershell(prefix, nexthop, intf) + if intf != "" || intfIdx != "" { + return addRoutePowershell(prefix, nexthop, intf, intfIdx) } return addRouteCmd(prefix, nexthop, intf) } @@ -117,11 +133,12 @@ func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { args := []string{"delete", prefix.String()} if nexthop.IsValid() { + nexthop.WithZone("") args = append(args, nexthop.Unmap().String()) } out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s output: %s", strings.Join(args, " "), out) + log.Tracef("route %s: %s", strings.Join(args, " "), out) if err != nil { return fmt.Errorf("remove route: %w", err) diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 06fac3bbf..4eda710ac 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -56,7 +56,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. connID := GenerateConnID() if dialerDialHooks != nil { - if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + if err := callDialerHooks(ctx, connID, address, resolver); err != nil { log.Errorf("Failed to call dialer hooks: %v", err) } } @@ -97,7 +97,7 @@ func (c *Conn) Close() error { return err } -func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { +func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { host, _, err := net.SplitHostPort(address) if err != nil { return fmt.Errorf("split host and port: %w", err) From c1f66d135487b6093fe899cb0f078ded3d465ed8 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 9 Apr 2024 15:27:19 +0200 Subject: [PATCH 23/28] Retry macOS route command (#1817) --- .../internal/routemanager/systemops_darwin.go | 32 ++++++++++++++-- .../routemanager/systemops_darwin_test.go | 38 +++++++++++++++++++ 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go index 33b8287b6..f7ce72a4e 100644 --- a/client/internal/routemanager/systemops_darwin.go +++ b/client/internal/routemanager/systemops_darwin.go @@ -8,7 +8,9 @@ import ( "net/netip" "os/exec" "strings" + "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/peer" @@ -55,11 +57,33 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin args = append(args, "-interface", intf) } - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) - - if err != nil { + if err := retryRouteCmd(args); err != nil { return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) } return nil } + +func retryRouteCmd(args []string) error { + operation := func() error { + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + // https://github.com/golang/go/issues/45736 + if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") { + return err + } else if err != nil { + return backoff.Permanent(err) + } + return nil + } + + expBackOff := backoff.NewExponentialBackOff() + expBackOff.InitialInterval = 50 * time.Millisecond + expBackOff.MaxInterval = 500 * time.Millisecond + expBackOff.MaxElapsedTime = 1 * time.Second + + err := backoff.Retry(operation, expBackOff) + if err != nil { + return fmt.Errorf("route cmd retry failed: %w", err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go index 5c5aaa24f..cc9bb9db5 100644 --- a/client/internal/routemanager/systemops_darwin_test.go +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -5,8 +5,10 @@ package routemanager import ( "fmt" "net" + "net/netip" "os/exec" "regexp" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -29,6 +31,42 @@ func init() { }...) } +func TestConcurrentRoutes(t *testing.T) { + baseIP := netip.MustParseAddr("192.0.2.0") + intf := "lo0" + + var wg sync.WaitGroup + for i := 0; i < 1024; i++ { + wg.Add(1) + go func(ip netip.Addr) { + defer wg.Done() + prefix := netip.PrefixFrom(ip, 32) + if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil { + t.Errorf("Failed to add route for %s: %v", prefix, err) + } + }(baseIP) + baseIP = baseIP.Next() + } + + wg.Wait() + + baseIP = netip.MustParseAddr("192.0.2.0") + + for i := 0; i < 1024; i++ { + wg.Add(1) + go func(ip netip.Addr) { + defer wg.Done() + prefix := netip.PrefixFrom(ip, 32) + if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil { + t.Errorf("Failed to remove route for %s: %v", prefix, err) + } + }(baseIP) + baseIP = baseIP.Next() + } + + wg.Wait() +} + func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { t.Helper() From 22b2caffc60ead6828d59d495b5670475cf95f3c Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:01:31 +0200 Subject: [PATCH 24/28] Remove dns based cloud detection (#1812) * remove dns based cloud checks * remove dns based cloud checks --- client/system/detect_cloud/detect.go | 2 - client/system/detect_cloud/gcp.go | 2 +- client/system/detect_cloud/ibmcloud.go | 54 ------------------------- client/system/detect_cloud/softlayer.go | 25 ------------ 4 files changed, 1 insertion(+), 82 deletions(-) delete mode 100644 client/system/detect_cloud/ibmcloud.go delete mode 100644 client/system/detect_cloud/softlayer.go diff --git a/client/system/detect_cloud/detect.go b/client/system/detect_cloud/detect.go index 3bbff4345..8a8de763e 100644 --- a/client/system/detect_cloud/detect.go +++ b/client/system/detect_cloud/detect.go @@ -25,8 +25,6 @@ func Detect(ctx context.Context) string { detectDigitalOcean, detectGCP, detectOracle, - detectIBMCloud, - detectSoftlayer, detectVultr, } diff --git a/client/system/detect_cloud/gcp.go b/client/system/detect_cloud/gcp.go index c673f8937..a24c38c0c 100644 --- a/client/system/detect_cloud/gcp.go +++ b/client/system/detect_cloud/gcp.go @@ -6,7 +6,7 @@ import ( ) func detectGCP(ctx context.Context) string { - req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil) + req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil) if err != nil { return "" } diff --git a/client/system/detect_cloud/ibmcloud.go b/client/system/detect_cloud/ibmcloud.go deleted file mode 100644 index 07de6a2ee..000000000 --- a/client/system/detect_cloud/ibmcloud.go +++ /dev/null @@ -1,54 +0,0 @@ -package detect_cloud - -import ( - "context" - "net/http" -) - -func detectIBMCloud(ctx context.Context) string { - v1ResultChan := make(chan bool, 1) - v2ResultChan := make(chan bool, 1) - - go func() { - v1ResultChan <- detectIBMSecure(ctx) - }() - - go func() { - v2ResultChan <- detectIBM(ctx) - }() - - v1Result, v2Result := <-v1ResultChan, <-v2ResultChan - - if v1Result || v2Result { - return "IBM Cloud" - } - return "" -} - -func detectIBMSecure(ctx context.Context) bool { - req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil) - if err != nil { - return false - } - - resp, err := hc.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == http.StatusOK -} - -func detectIBM(ctx context.Context) bool { - req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil) - if err != nil { - return false - } - - resp, err := hc.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == http.StatusOK -} diff --git a/client/system/detect_cloud/softlayer.go b/client/system/detect_cloud/softlayer.go deleted file mode 100644 index a09b522c4..000000000 --- a/client/system/detect_cloud/softlayer.go +++ /dev/null @@ -1,25 +0,0 @@ -package detect_cloud - -import ( - "context" - "net/http" -) - -func detectSoftlayer(ctx context.Context) string { - req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil) - if err != nil { - return "" - } - - resp, err := hc.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - // Since SoftLayer was acquired by IBM, we should return "IBM Cloud" - return "IBM Cloud" - } - return "" -} From dd0cf4114713069c11e2a829e0ee90769e399fcf Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 10 Apr 2024 03:10:59 +0900 Subject: [PATCH 25/28] Auto restart Windows agent daemon service (#1819) This enables auto restart of the windows agent daemon service on event of failure --- client/cmd/service_installer.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index da6beef4f..5e147262b 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -64,6 +64,10 @@ var installCmd = &cobra.Command{ } } + if runtime.GOOS == "windows" { + svcConfig.Option["OnFailure"] = "restart" + } + ctx, cancel := context.WithCancel(cmd.Context()) s, err := newSVC(newProgram(ctx, cancel), svcConfig) From 90bd39c74017cb4974e7b5d9de56180eb2e2b66b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 9 Apr 2024 20:27:27 +0200 Subject: [PATCH 26/28] Log panics (#1818) --- client/internal/connect.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/client/internal/connect.go b/client/internal/connect.go index b50b3a629..6b888c9cc 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "runtime" + "runtime/debug" "strings" "time" @@ -94,6 +95,12 @@ func runClient( relayProbe *Probe, wgProbe *Probe, ) error { + defer func() { + if r := recover(); r != nil { + log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + } + }() + log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) // Check if client was not shut down in a clean way and restore DNS config if required. From 4c83408f27259b7fecfc0fa933a2ad68cbcfffaa Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 10 Apr 2024 04:00:43 +0900 Subject: [PATCH 27/28] Add log-level to the management's docker service command (#1820) --- infrastructure_files/docker-compose.yml.tmpl | 1 + 1 file changed, 1 insertion(+) diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 47cadb93c..747eebd53 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -58,6 +58,7 @@ services: command: [ "--port", "443", "--log-file", "console", + "--log-level", "info", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" From 3ed2f08f3c5dd930a598a26f24cf028807816486 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 9 Apr 2024 21:20:02 +0200 Subject: [PATCH 28/28] Add latency based routing (#1732) Now that we have the latency between peers available we can use this data to consider when choosing the best route. This way the route with the routing peer with the lower latency will be preferred over others with the same target network. --- client/internal/routemanager/client.go | 38 ++++-- client/internal/routemanager/client_test.go | 142 ++++++++++++++++++-- 2 files changed, 163 insertions(+), 17 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 38cf4bf65..370ad5cf4 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "time" log "github.com/sirupsen/logrus" @@ -18,6 +19,7 @@ type routerPeerStatus struct { connected bool relayed bool direct bool + latency time.Duration } type routesUpdate struct { @@ -68,6 +70,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { connected: peerStatus.ConnStatus == peer.StatusConnected, relayed: peerStatus.Relayed, direct: peerStatus.Direct, + latency: peerStatus.Latency, } } return routePeerStatuses @@ -83,11 +86,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { // * Non-relayed: Routes without relays are preferred. // * Direct connections: Routes with direct peer connections are favored. // * Stability: In case of equal scores, the currently active route (if any) is maintained. +// * Latency: Routes with lower latency are prioritized. // // It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" - chosenScore := 0 + chosenScore := float64(0) + currScore := float64(0) currID := "" if c.chosenRoute != nil { @@ -95,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro } for _, r := range c.routes { - tempScore := 0 + tempScore := float64(0) peerStatus, found := routePeerStatuses[r.ID] if !found || !peerStatus.connected { continue @@ -103,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro if r.Metric < route.MaxMetric { metricDiff := route.MaxMetric - r.Metric - tempScore = metricDiff * 10 + tempScore = float64(metricDiff) * 10 } + // in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route + latency := time.Second + if peerStatus.latency != 0 { + latency = peerStatus.latency + } else { + log.Warnf("peer %s has 0 latency", r.Peer) + } + tempScore += 1 - latency.Seconds() + if !peerStatus.relayed { tempScore++ } @@ -114,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro tempScore++ } - if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) { + if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") { chosen = r.ID chosenScore = tempScore } @@ -123,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro chosen = r.ID chosenScore = tempScore } + + if r.ID == currID { + currScore = tempScore + } } - if chosen == "" { + switch { + case chosen == "": var peers []string for _, r := range c.routes { peers = append(peers, r.Peer) } log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers) - - } else if chosen != currID { - log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) + case chosen != currID: + if currScore != 0 && currScore < chosenScore+0.1 { + return currID + } else { + log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) + } } return chosen diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 3700d72ec..d24d42b8e 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -3,6 +3,7 @@ package routemanager import ( "net/netip" "testing" + "time" "github.com/netbirdio/netbird/route" ) @@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name string statuses map[string]routerPeerStatus expectedRouteID string - currentRoute *route.Route + currentRoute string existingRoutes map[string]*route.Route }{ { @@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "", }, { @@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, + { + name: "multiple connected peers with different latencies", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + latency: 300 * time.Millisecond, + }, + "route2": { + connected: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "should ignore routes with latency 0", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "current route with similar score and similar but slightly worse latency should not change", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + direct: true, + latency: 12 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + direct: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "current chosen route doesn't exist anymore", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + direct: true, + latency: 20 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + direct: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "routeDoesntExistAnymore", + expectedRouteID: "route2", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + currentRoute := &route.Route{ + ID: "routeDoesntExistAnymore", + } + if tc.currentRoute != "" { + currentRoute = tc.existingRoutes[tc.currentRoute] + } + // create new clientNetwork client := &clientNetwork{ network: netip.MustParsePrefix("192.168.0.0/24"), routes: tc.existingRoutes, - chosenRoute: tc.currentRoute, + chosenRoute: currentRoute, } chosenRoute := client.getBestRouteFromStatuses(tc.statuses)