[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: adminUser.Id,
AccountId: accountID,
Domain: "hotmail.com",

View File

@@ -9,9 +9,9 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/types"
)
// dnsSettingsHandler is a handler that returns the DNS settings of the account

View File

@@ -11,13 +11,14 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
AccountId: testingDNSSettingsAccount.Id,
Domain: testingDNSSettingsAccount.Domain,

View File

@@ -19,6 +19,7 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
AccountId: testNSGroupAccountID,
Domain: "hotmail.com",

View File

@@ -14,11 +14,12 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func initEventsTestData(account string, events ...*activity.Event) *handler {
@@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_account",

View File

@@ -11,10 +11,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns groups of the account

View File

@@ -19,12 +19,13 @@ import (
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
var TestPeers = map[string]*nbpeer.Peer{
@@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -12,15 +12,15 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/shared/management/status"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// handler is a handler that returns networks of the account

View File

@@ -8,10 +8,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
type resourceHandler struct {

View File

@@ -7,10 +7,10 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
type routersHandler struct {

View File

@@ -21,6 +21,7 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/stretchr/testify/assert"
@@ -296,7 +297,7 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "admin_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -444,7 +445,7 @@ func TestGetAccessiblePeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
@@ -527,7 +528,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -16,12 +16,13 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/util"
)
@@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -9,11 +9,11 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)

View File

@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns policy of the account

View File

@@ -14,10 +14,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
func initPoliciesTestData(policies ...*types.Policy) *handler {
@@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -9,9 +9,9 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
)

View File

@@ -16,9 +16,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
@@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,

View File

@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns a list of setup keys of the account

View File

@@ -15,10 +15,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: adminUser.Id,
Domain: "hotmail.com",
AccountId: "testAccountId",

View File

@@ -8,10 +8,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// patHandler is the nameserver group handler of the account

View File

@@ -17,10 +17,11 @@ import (
"github.com/netbirdio/netbird/management/server/util"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -128,7 +129,7 @@ func initUsersTestData() *handler {
return nil
},
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
switch userAuth.UserId {
case "not-found":
return nil, status.NewUserNotFoundError("not-found")
@@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder()
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) {
tt := []struct {
name string
expectedStatus int
requestAuth nbcontext.UserAuth
requestAuth auth.UserAuth
expectedResult *api.User
}{
{
@@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "user not found",
requestAuth: nbcontext.UserAuth{UserId: "not-found"},
requestAuth: auth.UserAuth{UserId: "not-found"},
expectedStatus: http.StatusNotFound,
},
{
name: "not of account",
requestAuth: nbcontext.UserAuth{UserId: "not-of-account"},
requestAuth: auth.UserAuth{UserId: "not-of-account"},
expectedStatus: http.StatusForbidden,
},
{
name: "blocked user",
requestAuth: nbcontext.UserAuth{UserId: "blocked-user"},
requestAuth: auth.UserAuth{UserId: "blocked-user"},
expectedStatus: http.StatusForbidden,
},
{
name: "service user",
requestAuth: nbcontext.UserAuth{UserId: "service-user"},
requestAuth: auth.UserAuth{UserId: "service-user"},
expectedStatus: http.StatusForbidden,
},
{
name: "owner",
requestAuth: nbcontext.UserAuth{UserId: "owner"},
requestAuth: auth.UserAuth{UserId: "owner"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "owner",
@@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "regular user",
requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
requestAuth: auth.UserAuth{UserId: "regular-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "regular-user",
@@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "admin user",
requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
requestAuth: auth.UserAuth{UserId: "admin-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "admin-user",
@@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "restricted user",
requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
requestAuth: auth.UserAuth{UserId: "restricted-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "restricted-user",
@@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) {
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
@@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) {
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}

View File

@@ -10,22 +10,23 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/auth"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
authManager auth.Manager
authManager serverauth.Manager
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
@@ -34,7 +35,7 @@ type AuthMiddleware struct {
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(
authManager auth.Manager,
authManager serverauth.Manager,
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
@@ -61,18 +62,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return
}
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(auth[0])
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(authHeader[0])
// fallback to token when receive pat as bearer
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
authType = "token"
auth[0] = authType
authHeader[0] = authType
}
switch authType {
case "bearer":
request, err := m.checkJWTFromRequest(r, auth)
request, err := m.checkJWTFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
@@ -81,7 +82,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, request)
case "token":
request, err := m.checkPATFromRequest(r, auth)
request, err := m.checkPATFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized
@@ -100,8 +101,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(auth)
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error
if err != nil {
@@ -151,8 +152,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(auth)
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return r, fmt.Errorf("error extracting token: %w", err)
}
@@ -177,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, err
}
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
UserId: user.Id,
AccountId: user.AccountID,
Domain: accDomain,

View File

@@ -12,11 +12,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbauth "github.com/netbirdio/netbird/shared/auth"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -75,9 +76,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
if token == JWT {
return nbcontext.UserAuth{
return nbauth.UserAuth{
UserId: userID,
AccountId: accountID,
Domain: testAccount.Domain,
@@ -91,7 +92,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
Valid: true,
}, nil
}
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
@@ -101,7 +102,7 @@ func mockMarkPATUsed(_ context.Context, token string) error {
return fmt.Errorf("Should never get reached")
}
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
@@ -197,13 +198,13 @@ func TestAuthMiddleware_Handler(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
@@ -255,13 +256,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -306,13 +307,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -348,13 +349,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -391,13 +392,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -454,13 +455,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
@@ -508,13 +509,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name string
path string
authHeader string
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
}{
{
name: "Valid PAT Token",
path: "/test",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -526,7 +527,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid PAT Token accesses child",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -539,7 +540,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token",
path: "/test",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -551,7 +552,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token with child",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -570,13 +571,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -18,8 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
serverauth "github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
@@ -33,6 +33,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
)
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
@@ -71,14 +72,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock())
am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &auth.MockManager{
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
@@ -123,8 +124,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_m
}
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
userAuth := nbcontext.UserAuth{}
func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) {
userAuth := auth.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":