mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: adminUser.Id,
|
||||
AccountId: accountID,
|
||||
Domain: "hotmail.com",
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
||||
|
||||
@@ -11,13 +11,14 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
|
||||
AccountId: testingDNSSettingsAccount.Id,
|
||||
Domain: testingDNSSettingsAccount.Domain,
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
AccountId: testNSGroupAccountID,
|
||||
Domain: "hotmail.com",
|
||||
|
||||
@@ -14,11 +14,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||
@@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_account",
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns groups of the account
|
||||
|
||||
@@ -19,12 +19,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
var TestPeers = map[string]*nbpeer.Peer{
|
||||
@@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -12,15 +12,15 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// handler is a handler that returns networks of the account
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
type resourceHandler struct {
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
type routersHandler struct {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -296,7 +297,7 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "admin_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -444,7 +445,7 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: tc.callerUserID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -527,7 +528,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: tc.callerUserID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -16,12 +16,13 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns policy of the account
|
||||
|
||||
@@ -14,10 +14,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func initPoliciesTestData(policies ...*types.Policy) *handler {
|
||||
@@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
|
||||
@@ -16,9 +16,10 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testAccountID,
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns a list of setup keys of the account
|
||||
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: adminUser.Id,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "testAccountId",
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// patHandler is the nameserver group handler of the account
|
||||
|
||||
@@ -17,10 +17,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -128,7 +129,7 @@ func initUsersTestData() *handler {
|
||||
|
||||
return nil
|
||||
},
|
||||
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
switch userAuth.UserId {
|
||||
case "not-found":
|
||||
return nil, status.NewUserNotFoundError("not-found")
|
||||
@@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
rr := httptest.NewRecorder()
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestAuth nbcontext.UserAuth
|
||||
requestAuth auth.UserAuth
|
||||
expectedResult *api.User
|
||||
}{
|
||||
{
|
||||
@@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "not-found"},
|
||||
requestAuth: auth.UserAuth{UserId: "not-found"},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "not of account",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "not-of-account"},
|
||||
requestAuth: auth.UserAuth{UserId: "not-of-account"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "blocked user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "blocked-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "blocked-user"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "service user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "service-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "service-user"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "owner",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "owner"},
|
||||
requestAuth: auth.UserAuth{UserId: "owner"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "owner",
|
||||
@@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "regular user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "regular-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "regular-user",
|
||||
@@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "admin user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "admin-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "admin-user",
|
||||
@@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "restricted user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "restricted-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "restricted-user",
|
||||
@@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
@@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) {
|
||||
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
|
||||
@@ -10,22 +10,23 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager auth.Manager
|
||||
authManager serverauth.Manager
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
@@ -34,7 +35,7 @@ type AuthMiddleware struct {
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(
|
||||
authManager auth.Manager,
|
||||
authManager serverauth.Manager,
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
@@ -61,18 +62,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(auth[0])
|
||||
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(authHeader[0])
|
||||
|
||||
// fallback to token when receive pat as bearer
|
||||
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
|
||||
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
|
||||
authType = "token"
|
||||
auth[0] = authType
|
||||
authHeader[0] = authType
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
request, err := m.checkJWTFromRequest(r, auth)
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
@@ -81,7 +82,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
case "token":
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
@@ -100,8 +101,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(auth)
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
@@ -151,8 +152,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
@@ -177,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, err
|
||||
}
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: user.Id,
|
||||
AccountId: user.AccountID,
|
||||
Domain: accDomain,
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
nbauth "github.com/netbirdio/netbird/shared/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -75,9 +76,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return nbcontext.UserAuth{
|
||||
return nbauth.UserAuth{
|
||||
UserId: userID,
|
||||
AccountId: accountID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -91,7 +92,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
|
||||
Valid: true,
|
||||
}, nil
|
||||
}
|
||||
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
@@ -101,7 +102,7 @@ func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return userAuth, nil
|
||||
}
|
||||
@@ -197,13 +198,13 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
@@ -255,13 +256,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -306,13 +307,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -348,13 +349,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -391,13 +392,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -454,13 +455,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
@@ -508,13 +509,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name string
|
||||
path string
|
||||
authHeader string
|
||||
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
|
||||
expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
|
||||
}{
|
||||
{
|
||||
name: "Valid PAT Token",
|
||||
path: "/test",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -526,7 +527,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid PAT Token accesses child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -539,7 +540,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid JWT Token",
|
||||
path: "/test",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -551,7 +552,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid JWT Token with child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -570,13 +571,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -18,8 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
||||
@@ -71,14 +72,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
ctx := context.Background()
|
||||
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock())
|
||||
am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManagerMock := &auth.MockManager{
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManagerMock := &serverauth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: authManager.MarkPATUsed,
|
||||
@@ -123,8 +124,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_m
|
||||
}
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) {
|
||||
userAuth := auth.UserAuth{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
|
||||
Reference in New Issue
Block a user