diff --git a/.github/workflows/install-test-darwin.yml b/.github/workflows/install-test-darwin.yml new file mode 100644 index 000000000..cdf0cae5a --- /dev/null +++ b/.github/workflows/install-test-darwin.yml @@ -0,0 +1,58 @@ +name: Test installation Darwin + +on: + push: + branches: + - main + pull_request: + paths: + - "release_files/install.sh" + +jobs: + install-cli-only: + runs-on: macos-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Rename brew package + if: ${{ matrix.check_bin_install }} + run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak + + - name: Run install script + run: | + sh ./release_files/install.sh + env: + SKIP_UI_APP: true + + - name: Run tests + run: | + if ! command -v netbird &> /dev/null; then + echo "Error: netbird is not installed" + exit 1 + fi + install-all: + runs-on: macos-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Rename brew package + if: ${{ matrix.check_bin_install }} + run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak + + - name: Run install script + run: | + sh ./release_files/install.sh + + - name: Run tests + run: | + if ! command -v netbird &> /dev/null; then + echo "Error: netbird is not installed" + exit 1 + fi + + if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then + echo "Error: NetBird UI is not installed" + exit 1 + fi diff --git a/.github/workflows/install-test-linux.yml b/.github/workflows/install-test-linux.yml new file mode 100644 index 000000000..d4246881c --- /dev/null +++ b/.github/workflows/install-test-linux.yml @@ -0,0 +1,36 @@ +name: Test installation Linux + +on: + push: + branches: + - main + pull_request: + paths: + - "release_files/install.sh" + +jobs: + install-cli-only: + runs-on: ubuntu-latest + strategy: + matrix: + check_bin_install: [true, false] + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Rename apt package + if: ${{ matrix.check_bin_install }} + run: | + sudo mv /usr/bin/apt /usr/bin/apt.bak + sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak + + - name: Run install script + run: | + sh ./release_files/install.sh + + - name: Run tests + run: | + if ! command -v netbird &> /dev/null; then + echo "Error: netbird is not installed" + exit 1 + fi diff --git a/.github/workflows/test-docker-compose-linux.yml b/.github/workflows/test-docker-compose-linux.yml index d681dd89c..c28e94a4f 100644 --- a/.github/workflows/test-docker-compose-linux.yml +++ b/.github/workflows/test-docker-compose-linux.yml @@ -59,6 +59,10 @@ jobs: CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code CI_NETBIRD_AUTH_REDIRECT_URI: "/peers" + CI_NETBIRD_TOKEN_SOURCE: "idToken" + CI_NETBIRD_AUTH_USER_ID_CLAIM: "email" + CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super" + run: | grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY @@ -68,6 +72,10 @@ jobs: grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073" grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$' + grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' + grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE + grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM + grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE - name: run docker compose up working-directory: infrastructure_files diff --git a/client/android/client.go b/client/android/client.go index ac16316ed..3a7c2c8dc 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -78,7 +78,7 @@ func (c *Client) Run(urlOpener URLOpener) error { c.ctxCancelLock.Unlock() auth := NewAuthWithConfig(ctx, cfg) - err = auth.Login(urlOpener) + err = auth.login(urlOpener) if err != nil { return err } @@ -118,12 +118,12 @@ func (c *Client) PeersList() *PeerInfoArray { return &PeerInfoArray{items: peerInfos} } -// AddConnectionListener add new network connection listener -func (c *Client) AddConnectionListener(listener ConnectionListener) { - c.recorder.AddConnectionListener(listener) +// SetConnectionListener set the network connection listener +func (c *Client) SetConnectionListener(listener ConnectionListener) { + c.recorder.SetConnectionListener(listener) } // RemoveConnectionListener remove connection listener -func (c *Client) RemoveConnectionListener(listener ConnectionListener) { - c.recorder.RemoveConnectionListener(listener) +func (c *Client) RemoveConnectionListener() { + c.recorder.RemoveConnectionListener() } diff --git a/client/android/login.go b/client/android/login.go index 4e2f1ab30..518942cb6 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -17,6 +17,18 @@ import ( "github.com/netbirdio/netbird/client/internal" ) +// SSOListener is async listener for mobile framework +type SSOListener interface { + OnSuccess(bool) + OnError(error) +} + +// ErrListener is async listener for mobile framework +type ErrListener interface { + OnSuccess() + OnError(error) +} + // URLOpener it is a callback interface. The Open function will be triggered if // the backend want to show an url for the user type URLOpener interface { @@ -59,7 +71,18 @@ func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { // SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info. // If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO // is not supported and returns false without saving the configuration. For other errors return false. -func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { +func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { + go func() { + sso, err := a.saveConfigIfSSOSupported() + if err != nil { + listener.OnError(err) + } else { + listener.OnSuccess(sso) + } + }() +} + +func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) @@ -83,7 +106,18 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { } // LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key. -func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { +func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) { + go func() { + err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) @@ -103,7 +137,18 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string } // Login try register the client on the server -func (a *Auth) Login(urlOpener URLOpener) error { +func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { + go func() { + err := a.login(urlOpener) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) login(urlOpener URLOpener) error { var needsLogin bool // check if we need to generate JWT token @@ -121,7 +166,7 @@ func (a *Auth) Login(urlOpener URLOpener) error { if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } - jwtToken = tokenInfo.AccessToken + jwtToken = tokenInfo.GetTokenToUse() } err = a.withBackOff(a.ctx, func() error { @@ -154,12 +199,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, } } - hostedClient := internal.NewHostedDeviceFlow( - providerConfig.ProviderConfig.Audience, - providerConfig.ProviderConfig.ClientID, - providerConfig.ProviderConfig.TokenEndpoint, - providerConfig.ProviderConfig.DeviceAuthEndpoint, - ) + hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := hostedClient.RequestDeviceCode(context.TODO()) if err != nil { diff --git a/client/cmd/login.go b/client/cmd/login.go index 13b4b335c..92d69b6ee 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -135,7 +135,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } - jwtToken = tokenInfo.AccessToken + jwtToken = tokenInfo.GetTokenToUse() } err = WithBackOff(func() error { @@ -172,12 +172,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int } } - hostedClient := internal.NewHostedDeviceFlow( - providerConfig.ProviderConfig.Audience, - providerConfig.ProviderConfig.ClientID, - providerConfig.ProviderConfig.TokenEndpoint, - providerConfig.ProviderConfig.DeviceAuthEndpoint, - ) + hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := hostedClient.RequestDeviceCode(context.TODO()) if err != nil { diff --git a/client/internal/connect.go b/client/internal/connect.go index 62b899d96..44f046191 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -164,6 +164,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, state.Set(StatusConnected) <-engineCtx.Done() + statusRecorder.ClientTeardown() backOff.Reset() diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index d2396242b..0273bb8e4 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -34,6 +34,10 @@ type ProviderConfig struct { TokenEndpoint string // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code DeviceAuthEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool } // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it @@ -91,9 +95,16 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain, TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(), + Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(), + UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(), }, } + // keep compatibility with older management versions + if deviceAuthorizationFlow.ProviderConfig.Scope == "" { + deviceAuthorizationFlow.ProviderConfig.Scope = "openid" + } + err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig) if err != nil { return DeviceAuthorizationFlow{}, err @@ -116,5 +127,8 @@ func isProviderConfigValid(config ProviderConfig) error { if config.DeviceAuthEndpoint == "" { return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint") } + if config.Scope == "" { + return fmt.Errorf(errorMSGFormat, "Device Auth Scopes") + } return nil } diff --git a/client/internal/oauth.go b/client/internal/oauth.go index ae327a620..2d237925d 100644 --- a/client/internal/oauth.go +++ b/client/internal/oauth.go @@ -35,15 +35,6 @@ type DeviceAuthInfo struct { Interval int `json:"interval"` } -// TokenInfo holds information of issued access token -type TokenInfo struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - IDToken string `json:"id_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` -} - // HostedGrantType grant type for device flow on Hosted const ( HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code" @@ -52,16 +43,7 @@ const ( // Hosted client type Hosted struct { - // Hosted API Audience for validation - Audience string - // Hosted Native application client id - ClientID string - // Hosted Native application request scope - Scope string - // TokenEndpoint to request access token - TokenEndpoint string - // DeviceAuthEndpoint to request device authorization code - DeviceAuthEndpoint string + providerConfig ProviderConfig HTTPClient HTTPClient } @@ -70,7 +52,7 @@ type Hosted struct { type RequestDeviceCodePayload struct { Audience string `json:"audience"` ClientID string `json:"client_id"` - Scope string `json:"scope"` + Scope string `json:"scope"` } // TokenRequestPayload used for requesting the auth0 token @@ -93,8 +75,26 @@ type Claims struct { Audience interface{} `json:"aud"` } +// TokenInfo holds information of issued access token +type TokenInfo struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + UseIDToken bool `json:"-"` +} + +// GetTokenToUse returns either the access or id token based on UseIDToken field +func (t TokenInfo) GetTokenToUse() string { + if t.UseIDToken { + return t.IDToken + } + return t.AccessToken +} + // NewHostedDeviceFlow returns an Hosted OAuth client -func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted { +func NewHostedDeviceFlow(config ProviderConfig) *Hosted { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 @@ -104,27 +104,23 @@ func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, } return &Hosted{ - Audience: audience, - ClientID: clientID, - Scope: "openid", - TokenEndpoint: tokenEndpoint, - HTTPClient: httpClient, - DeviceAuthEndpoint: deviceAuthEndpoint, + providerConfig: config, + HTTPClient: httpClient, } } // GetClientID returns the provider client id func (h *Hosted) GetClientID(ctx context.Context) string { - return h.ClientID + return h.providerConfig.ClientID } // RequestDeviceCode requests a device code login flow information from Hosted func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) { form := url.Values{} - form.Add("client_id", h.ClientID) - form.Add("audience", h.Audience) - form.Add("scope", h.Scope) - req, err := http.NewRequest("POST", h.DeviceAuthEndpoint, + form.Add("client_id", h.providerConfig.ClientID) + form.Add("audience", h.providerConfig.Audience) + form.Add("scope", h.providerConfig.Scope) + req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint, strings.NewReader(form.Encode())) if err != nil { return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err) @@ -157,10 +153,10 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) { form := url.Values{} - form.Add("client_id", h.ClientID) + form.Add("client_id", h.providerConfig.ClientID) form.Add("grant_type", HostedGrantType) form.Add("device_code", info.DeviceCode) - req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode())) + req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err) } @@ -225,18 +221,20 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription) } - err = isValidAccessToken(tokenResponse.AccessToken, h.Audience) - if err != nil { - return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) - } - tokenInfo := TokenInfo{ AccessToken: tokenResponse.AccessToken, TokenType: tokenResponse.TokenType, RefreshToken: tokenResponse.RefreshToken, IDToken: tokenResponse.IDToken, ExpiresIn: tokenResponse.ExpiresIn, + UseIDToken: h.providerConfig.UseIDToken, } + + err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience) + if err != nil { + return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) + } + return tokenInfo, err } } diff --git a/client/internal/oauth_test.go b/client/internal/oauth_test.go index 3a9e2a0c2..aa71fa0eb 100644 --- a/client/internal/oauth_test.go +++ b/client/internal/oauth_test.go @@ -3,14 +3,15 @@ package internal import ( "context" "fmt" - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/require" "io" "net/http" "net/url" "strings" "testing" "time" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/require" ) type mockHTTPClient struct { @@ -113,12 +114,15 @@ func TestHosted_RequestDeviceCode(t *testing.T) { } hosted := Hosted{ - Audience: expectedAudience, - ClientID: expectedClientID, - Scope: expectedScope, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - HTTPClient: &httpClient, + providerConfig: ProviderConfig{ + Audience: expectedAudience, + ClientID: expectedClientID, + Scope: expectedScope, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + UseIDToken: false, + }, + HTTPClient: &httpClient, } authInfo, err := hosted.RequestDeviceCode(context.TODO()) @@ -275,12 +279,15 @@ func TestHosted_WaitToken(t *testing.T) { } hosted := Hosted{ - Audience: testCase.inputAudience, - ClientID: clientID, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - HTTPClient: &httpClient, - } + providerConfig: ProviderConfig{ + Audience: testCase.inputAudience, + ClientID: clientID, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + Scope: "openid", + UseIDToken: false, + }, + HTTPClient: &httpClient} ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) defer cancel() diff --git a/client/internal/peer/listener.go b/client/internal/peer/listener.go index c8dc0fe70..c601fe534 100644 --- a/client/internal/peer/listener.go +++ b/client/internal/peer/listener.go @@ -5,6 +5,7 @@ type Listener interface { OnConnected() OnDisconnected() OnConnecting() + OnDisconnecting() OnAddressChanged(string, string) OnPeersListChanged(int) } diff --git a/client/internal/peer/notifier.go b/client/internal/peer/notifier.go index efc9e47ad..b2d324c6c 100644 --- a/client/internal/peer/notifier.go +++ b/client/internal/peer/notifier.go @@ -8,37 +8,37 @@ const ( stateDisconnected = iota stateConnected stateConnecting + stateDisconnecting ) type notifier struct { serverStateLock sync.Mutex listenersLock sync.Mutex - listeners map[Listener]struct{} + listener Listener currentServerState bool currentClientState bool lastNotification int } func newNotifier() *notifier { - return ¬ifier{ - listeners: make(map[Listener]struct{}), - } + return ¬ifier{} } -func (n *notifier) addListener(listener Listener) { +func (n *notifier) setListener(listener Listener) { n.listenersLock.Lock() defer n.listenersLock.Unlock() n.serverStateLock.Lock() - go n.notifyListener(listener, n.lastNotification) + n.notifyListener(listener, n.lastNotification) n.serverStateLock.Unlock() - n.listeners[listener] = struct{}{} + + n.listener = listener } -func (n *notifier) removeListener(listener Listener) { +func (n *notifier) removeListener() { n.listenersLock.Lock() defer n.listenersLock.Unlock() - delete(n.listeners, listener) + n.listener = nil } func (n *notifier) updateServerStates(mgmState bool, signalState bool) { @@ -57,9 +57,13 @@ func (n *notifier) updateServerStates(mgmState bool, signalState bool) { } n.currentServerState = newState - n.lastNotification = n.calculateState(newState, n.currentClientState) - go n.notifyAll(n.lastNotification) + if n.lastNotification == stateDisconnecting { + return + } + + n.lastNotification = n.calculateState(newState, n.currentClientState) + n.notify(n.lastNotification) } func (n *notifier) clientStart() { @@ -67,7 +71,7 @@ func (n *notifier) clientStart() { defer n.serverStateLock.Unlock() n.currentClientState = true n.lastNotification = n.calculateState(n.currentServerState, true) - go n.notifyAll(n.lastNotification) + n.notify(n.lastNotification) } func (n *notifier) clientStop() { @@ -75,31 +79,43 @@ func (n *notifier) clientStop() { defer n.serverStateLock.Unlock() n.currentClientState = false n.lastNotification = n.calculateState(n.currentServerState, false) - go n.notifyAll(n.lastNotification) + n.notify(n.lastNotification) +} + +func (n *notifier) clientTearDown() { + n.serverStateLock.Lock() + defer n.serverStateLock.Unlock() + n.currentClientState = false + n.lastNotification = stateDisconnecting + n.notify(n.lastNotification) } func (n *notifier) isServerStateChanged(newState bool) bool { return n.currentServerState != newState } -func (n *notifier) notifyAll(state int) { +func (n *notifier) notify(state int) { n.listenersLock.Lock() defer n.listenersLock.Unlock() - - for l := range n.listeners { - n.notifyListener(l, state) + if n.listener == nil { + return } + n.notifyListener(n.listener, state) } func (n *notifier) notifyListener(l Listener, state int) { - switch state { - case stateDisconnected: - l.OnDisconnected() - case stateConnected: - l.OnConnected() - case stateConnecting: - l.OnConnecting() - } + go func() { + switch state { + case stateDisconnected: + l.OnDisconnected() + case stateConnected: + l.OnConnected() + case stateConnecting: + l.OnConnecting() + case stateDisconnecting: + l.OnDisconnecting() + } + }() } func (n *notifier) calculateState(serverState bool, clientState bool) int { @@ -117,17 +133,17 @@ func (n *notifier) calculateState(serverState bool, clientState bool) int { func (n *notifier) peerListChanged(numOfPeers int) { n.listenersLock.Lock() defer n.listenersLock.Unlock() - - for l := range n.listeners { - l.OnPeersListChanged(numOfPeers) + if n.listener == nil { + return } + n.listener.OnPeersListChanged(numOfPeers) } func (n *notifier) localAddressChanged(fqdn, address string) { n.listenersLock.Lock() defer n.listenersLock.Unlock() - - for l := range n.listeners { - l.OnAddressChanged(fqdn, address) + if n.listener == nil { + return } + n.listener.OnAddressChanged(fqdn, address) } diff --git a/client/internal/peer/notifier_test.go b/client/internal/peer/notifier_test.go index f21193e06..a9045ac34 100644 --- a/client/internal/peer/notifier_test.go +++ b/client/internal/peer/notifier_test.go @@ -1,9 +1,48 @@ package peer import ( + "sync" "testing" ) +type mocListener struct { + lastState int + wg sync.WaitGroup + peers int +} + +func (l *mocListener) OnConnected() { + l.lastState = stateConnected + l.wg.Done() +} +func (l *mocListener) OnDisconnected() { + l.lastState = stateDisconnected + l.wg.Done() +} +func (l *mocListener) OnConnecting() { + l.lastState = stateConnecting + l.wg.Done() +} +func (l *mocListener) OnDisconnecting() { + l.lastState = stateDisconnecting + l.wg.Done() +} + +func (l *mocListener) OnAddressChanged(host, addr string) { + +} +func (l *mocListener) OnPeersListChanged(size int) { + l.peers = size +} + +func (l *mocListener) setWaiter() { + l.wg.Add(1) +} + +func (l *mocListener) wait() { + l.wg.Wait() +} + func Test_notifier_serverState(t *testing.T) { type scenario struct { @@ -30,3 +69,30 @@ func Test_notifier_serverState(t *testing.T) { }) } } + +func Test_notifier_SetListener(t *testing.T) { + listener := &mocListener{} + listener.setWaiter() + + n := newNotifier() + n.lastNotification = stateConnecting + n.setListener(listener) + listener.wait() + if listener.lastState != n.lastNotification { + t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) + } +} + +func Test_notifier_RemoveListener(t *testing.T) { + listener := &mocListener{} + listener.setWaiter() + n := newNotifier() + n.lastNotification = stateConnecting + n.setListener(listener) + n.removeListener() + n.peerListChanged(1) + + if listener.peers != 0 { + t.Errorf("invalid state: %d", listener.peers) + } +} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 1ecdff301..508131816 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -288,14 +288,19 @@ func (d *Status) ClientStop() { d.notifier.clientStop() } -// AddConnectionListener add a listener to the notifier -func (d *Status) AddConnectionListener(listener Listener) { - d.notifier.addListener(listener) +// ClientTeardown will notify all listeners about the service is under teardown +func (d *Status) ClientTeardown() { + d.notifier.clientTearDown() } -// RemoveConnectionListener remove a listener from the notifier -func (d *Status) RemoveConnectionListener(listener Listener) { - d.notifier.removeListener(listener) +// SetConnectionListener set a listener to the notifier +func (d *Status) SetConnectionListener(listener Listener) { + d.notifier.setListener(listener) +} + +// RemoveConnectionListener remove the listener from the notifier +func (d *Status) RemoveConnectionListener() { + d.notifier.removeListener() } func (d *Status) onConnectionChanged() { diff --git a/client/server/server.go b/client/server/server.go index 6d5a08c59..fba82c7e4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -223,12 +223,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } } - hostedClient := internal.NewHostedDeviceFlow( - providerConfig.ProviderConfig.Audience, - providerConfig.ProviderConfig.ClientID, - providerConfig.ProviderConfig.TokenEndpoint, - providerConfig.ProviderConfig.DeviceAuthEndpoint, - ) + hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { @@ -344,7 +339,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin s.oauthAuthFlow.expiresAt = time.Now() s.mutex.Unlock() - if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.AccessToken); err != nil { + if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.GetTokenToUse()); err != nil { state.Set(loginStatus) return nil, err } diff --git a/encryption/encryption.go b/encryption/encryption.go index 196c42106..1c6ec7806 100644 --- a/encryption/encryption.go +++ b/encryption/encryption.go @@ -3,10 +3,13 @@ package encryption import ( "crypto/rand" "fmt" + "golang.org/x/crypto/nacl/box" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +const nonceSize = 24 + // A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service // These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate) // Wireguard keys are used for encryption @@ -26,8 +29,11 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes. if err != nil { return nil, err } - copy(nonce[:], encryptedMsg[:24]) - opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey)) + if len(encryptedMsg) < nonceSize { + return nil, fmt.Errorf("invalid encrypted message lenght") + } + copy(nonce[:], encryptedMsg[:nonceSize]) + opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey)) if !ok { return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String()) } @@ -36,8 +42,8 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes. } // Generates nonce of size 24 -func genNonce() (*[24]byte, error) { - var nonce [24]byte +func genNonce() (*[nonceSize]byte, error) { + var nonce [nonceSize]byte if _, err := rand.Read(nonce[:]); err != nil { return nil, err } diff --git a/go.mod b/go.mod index 64e0fbdc9..f34004c4e 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/godbus/dbus/v5 v5.1.0 + github.com/google/go-cmp v0.5.9 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 @@ -57,6 +58,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.33.0 go.opentelemetry.io/otel/metric v0.33.0 go.opentelemetry.io/otel/sdk/metric v0.33.0 + golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf golang.org/x/net v0.8.0 golang.org/x/sync v0.1.0 golang.org/x/term v0.6.0 @@ -91,7 +93,6 @@ require ( github.com/go-stack/stack v1.8.0 // indirect github.com/gobwas/glob v0.2.3 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect - github.com/google/go-cmp v0.5.9 // indirect github.com/google/gopacket v1.1.19 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -126,7 +127,6 @@ require ( go.opentelemetry.io/otel v1.11.1 // indirect go.opentelemetry.io/otel/sdk v1.11.1 // indirect go.opentelemetry.io/otel/trace v1.11.1 // indirect - golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect golang.org/x/mod v0.8.0 // indirect golang.org/x/text v0.8.0 // indirect diff --git a/iface/ipc_parser_android.go b/iface/ipc_parser_android.go index ef757a638..e1dd66856 100644 --- a/iface/ipc_parser_android.go +++ b/iface/ipc_parser_android.go @@ -33,7 +33,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { if p.PresharedKey != nil { preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) - sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey)) + sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey)) } if p.Remove { diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index 2d74e3a66..8fa58ffc3 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -3,26 +3,30 @@ # Management API # Management API port -NETBIRD_MGMT_API_PORT=33073 +NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073} # Management API endpoint address, used by the Dashboard NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT # Management Certficate file path. These are generated by the Dashboard container -NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem" +NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem" # Management Certficate key file path. -NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem" +NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem" # By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted} -# Turn credentials +# Signal +NETBIRD_SIGNAL_PROTOCOL="http" +NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000} + +# Turn credentials # User TURN_USER=self # Password. If empty, the configure.sh will generate one with openssl TURN_PASSWORD= # Min port -TURN_MIN_PORT=49152 +TURN_MIN_PORT=${TURN_MIN_PORT:-49152} # Max port -TURN_MAX_PORT=65535 +TURN_MAX_PORT=${TURN_MAX_PORT:-65535} VOLUME_PREFIX="netbird-" MGMT_VOLUMESUFFIX="mgmt" @@ -32,6 +36,8 @@ LETSENCRYPT_VOLUMESUFFIX="letsencrypt" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" NETBIRD_DISABLE_ANONYMOUS_METRICS=${NETBIRD_DISABLE_ANONYMOUS_METRICS:-false} +NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=${NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE:-$NETBIRD_AUTH_AUDIENCE} +NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken} # exports export NETBIRD_DOMAIN @@ -61,4 +67,9 @@ export SIGNAL_VOLUMESUFFIX export LETSENCRYPT_VOLUMESUFFIX export NETBIRD_DISABLE_ANONYMOUS_METRICS export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN -export NETBIRD_MGMT_DNS_DOMAIN \ No newline at end of file +export NETBIRD_MGMT_DNS_DOMAIN +export NETBIRD_SIGNAL_PROTOCOL +export NETBIRD_SIGNAL_PORT +export NETBIRD_AUTH_USER_ID_CLAIM +export NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE +export NETBIRD_TOKEN_SOURCE \ No newline at end of file diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index ed6367171..501098a57 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -121,6 +121,32 @@ if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted" fi +# Check if letsencrypt was disabled +if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]] +then + export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443" + export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT" + + echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore" + echo " and a reverse-proxy with Https needs to be placed in front of netbird!" + echo "The following forwards have to be setup:" + echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80" + echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT" + echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT" + echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80" + echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script." + echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!" + echo "You are also free to remove any occurences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" + echo "" + + export NETBIRD_SIGNAL_PROTOCOL="https" + unset NETBIRD_LETSENCRYPT_DOMAIN + unset NETBIRD_MGMT_API_CERT_FILE + unset NETBIRD_MGMT_API_CERT_KEY_FILE +else + export NETBIRD_LETSENCRYPT_DOMAIN="$NETBIRD_DOMAIN" +fi + env | grep NETBIRD envsubst < docker-compose.yml.tmpl > docker-compose.yml diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 296201710..af7f1af00 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -8,20 +8,26 @@ services: - 80:80 - 443:443 environment: + # Endpoints + - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + # OIDC - AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY - USE_AUTH0=$NETBIRD_USE_AUTH0 - AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES - - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - - NGINX_SSL_PORT=443 - - LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN - - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL - AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI - AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI + - NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE + # SSL + - NGINX_SSL_PORT=443 + # Letsencrypt + - LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN + - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL volumes: - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ + # Signal signal: image: netbirdio/signal:latest @@ -32,7 +38,8 @@ services: - 10000:80 # # port and command for Let's Encrypt validation # - 443:443 - # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + # Management management: image: netbirdio/management:latest @@ -46,8 +53,15 @@ services: ports: - $NETBIRD_MGMT_API_PORT:443 #API port # # command for Let's Encrypt validation without dashboard container - # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] - command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"] + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--port", "443", + "--log-file", "console", + "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", + "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", + "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" + ] + # Coturn coturn: image: coturn/coturn @@ -60,6 +74,7 @@ services: network_mode: host command: - -c /etc/turnserver.conf + volumes: $MGMT_VOLUMENAME: $SIGNAL_VOLUMENAME: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik new file mode 100644 index 000000000..9c1e0fd03 --- /dev/null +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -0,0 +1,99 @@ +version: "3" +services: + #UI dashboard + dashboard: + image: wiretrustee/dashboard:latest + restart: unless-stopped + #ports: + # - 80:80 + # - 443:443 + environment: + # Endpoints + - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + # OIDC + - AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE + - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID + - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY + - USE_AUTH0=$NETBIRD_USE_AUTH0 + - AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES + - AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI + - AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI + # SSL + - NGINX_SSL_PORT=443 + # Letsencrypt + - LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN + - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL + volumes: + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ + labels: + - traefik.enable=true + - traefik.http.routers.netbird-dashboard.rule=Host(`$NETBIRD_DOMAIN`) + - traefik.http.services.netbird-dashboard.loadbalancer.server.port=80 + + # Signal + signal: + image: netbirdio/signal:latest + restart: unless-stopped + volumes: + - $SIGNAL_VOLUMENAME:/var/lib/netbird + #ports: + # - 10000:80 + # # port and command for Let's Encrypt validation + # - 443:443 + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + labels: + - traefik.enable=true + - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.services.netbird-signal.loadbalancer.server.port=80 + - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c + + # Management + management: + image: netbirdio/management:latest + restart: unless-stopped + depends_on: + - dashboard + volumes: + - $MGMT_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro + - ./management.json:/etc/netbird/management.json + #ports: + # - $NETBIRD_MGMT_API_PORT:443 #API port + # # command for Let's Encrypt validation without dashboard container + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--port", "443", + "--log-file", "console", + "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", + "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", + "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" + ] + labels: + - traefik.enable=true + - traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`) + - traefik.http.routers.netbird-api.service=netbird-api + - traefik.http.services.netbird-api.loadbalancer.server.port=443 + + - traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`) + - traefik.http.routers.netbird-management.service=netbird-management + - traefik.http.services.netbird-management.loadbalancer.server.port=443 + - traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c + + # Coturn + coturn: + image: coturn/coturn + restart: unless-stopped + domainname: $NETBIRD_DOMAIN + volumes: + - ./turnserver.conf:/etc/turnserver.conf:ro + # - ./privkey.pem:/etc/coturn/private/privkey.pem:ro + # - ./cert.pem:/etc/coturn/certs/cert.pem:ro + network_mode: host + command: + - -c /etc/turnserver.conf + +volumes: + $MGMT_VOLUMENAME: + $SIGNAL_VOLUMENAME: + $LETSENCRYPT_VOLUMENAME: diff --git a/infrastructure_files/management.json.tmpl b/infrastructure_files/management.json.tmpl index f3b08101c..19dcff898 100644 --- a/infrastructure_files/management.json.tmpl +++ b/infrastructure_files/management.json.tmpl @@ -21,8 +21,8 @@ "TimeBasedCredentials": false }, "Signal": { - "Proto": "http", - "URI": "$NETBIRD_DOMAIN:10000", + "Proto": "$NETBIRD_SIGNAL_PROTOCOL", + "URI": "$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT", "Username": "", "Password": null }, @@ -43,7 +43,7 @@ "DeviceAuthorizationFlow": { "Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER", "ProviderConfig": { - "Audience": "$NETBIRD_AUTH_AUDIENCE", + "Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE", "Domain": "$NETBIRD_AUTH0_DOMAIN", "ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID", "TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT", diff --git a/infrastructure_files/setup.env.example b/infrastructure_files/setup.env.example index 09f407225..324174757 100644 --- a/infrastructure_files/setup.env.example +++ b/infrastructure_files/setup.env.example @@ -2,7 +2,11 @@ ## # Dashboard domain. e.g. app.mydomain.com NETBIRD_DOMAIN="" -# OIDC configuration e.g., https://example.eu.auth0.com/.well-known/openid-configuration + +# ------------------------------------------- +# OIDC +# e.g., https://example.eu.auth0.com/.well-known/openid-configuration +# ------------------------------------------- NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="" NETBIRD_AUTH_AUDIENCE="" # e.g. netbird-client @@ -13,14 +17,27 @@ NETBIRD_AUTH_CLIENT_ID="" NETBIRD_USE_AUTH0="false" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID="" -# e.g. hello@mydomain.com -NETBIRD_LETSENCRYPT_EMAIL="" +# Some IDPs requires different audience for device authorization flow, you can customize here +NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE + # if your IDP provider doesn't support fragmented URIs, configure custom # redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain. # NETBIRD_AUTH_REDIRECT_URI="/peers" # NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers" +# Updates the preference to use id tokens instead of access token on dashboard +# Okta and Gitlab IDPs can benefit from this +# NETBIRD_TOKEN_SOURCE="idToken" + +# ------------------------------------------- +# Letsencrypt +# ------------------------------------------- +# Disable letsencrypt +# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead +NETBIRD_DISABLE_LETSENCRYPT=false +# e.g. hello@mydomain.com +NETBIRD_LETSENCRYPT_EMAIL="" # Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection NETBIRD_DISABLE_ANONYMOUS_METRICS=false # DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted -NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted +NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted \ No newline at end of file diff --git a/infrastructure_files/tests/setup.env b/infrastructure_files/tests/setup.env index cdb5e5c6b..09164a135 100644 --- a/infrastructure_files/tests/setup.env +++ b/infrastructure_files/tests/setup.env @@ -11,4 +11,8 @@ NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0 NETBIRD_AUTH_AUDIENCE=$CI_NETBIRD_AUTH_AUDIENCE # e.g. hello@mydomain.com NETBIRD_LETSENCRYPT_EMAIL="" -NETBIRD_AUTH_REDIRECT_URI="/peers" \ No newline at end of file +NETBIRD_AUTH_REDIRECT_URI="/peers" +NETBIRD_DISABLE_LETSENCRYPT=true +NETBIRD_TOKEN_SOURCE="idToken" +NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE="super" +NETBIRD_AUTH_USER_ID_CLAIM="email" \ No newline at end of file diff --git a/management/cmd/management.go b/management/cmd/management.go index f3210d88e..d956fcff5 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -19,25 +19,28 @@ import ( "github.com/google/uuid" "github.com/miekg/dns" - "github.com/netbirdio/netbird/management/server/activity/sqlite" - httpapi "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/metrics" - "github.com/netbirdio/netbird/management/server/telemetry" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/netbirdio/netbird/management/server/activity/sqlite" + httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/encryption" - mgmtProto "github.com/netbirdio/netbird/management/proto" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + + "github.com/netbirdio/netbird/encryption" + mgmtProto "github.com/netbirdio/netbird/management/proto" ) // ManagementLegacyPort is the port that was used before by the Management gRPC server. @@ -179,13 +182,22 @@ var ( tlsEnabled = true } + jwtValidator, err := jwtclaims.NewJWTValidator( + config.HttpConfig.AuthIssuer, + config.GetAuthAudiences(), + config.HttpConfig.AuthKeysLocation, + ) + if err != nil { + return fmt.Errorf("failed creating JWT validator: %v", err) + } + httpAPIAuthCfg := httpapi.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, KeysLocation: config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg) + httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -405,6 +417,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", u.Host, config.DeviceAuthorizationFlow.ProviderConfig.Domain) config.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host + + if config.DeviceAuthorizationFlow.ProviderConfig.Scope == "" { + config.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope + } } } diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 022cc1408..ff2133526 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,15 +1,15 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.12.4 +// protoc v3.21.9 // source: management.proto package proto import ( - timestamp "github.com/golang/protobuf/ptypes/timestamp" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" ) @@ -611,7 +611,7 @@ type ServerKeyResponse struct { // Server's Wireguard public key Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` // Key expiration timestamp after which the key should be fetched again by the client - ExpiresAt *timestamp.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"` + ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"` // Version of the Wiretrustee Management Service protocol Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"` } @@ -655,7 +655,7 @@ func (x *ServerKeyResponse) GetKey() string { return "" } -func (x *ServerKeyResponse) GetExpiresAt() *timestamp.Timestamp { +func (x *ServerKeyResponse) GetExpiresAt() *timestamppb.Timestamp { if x != nil { return x.ExpiresAt } @@ -1331,6 +1331,10 @@ type ProviderConfig struct { DeviceAuthEndpoint string `protobuf:"bytes,5,opt,name=DeviceAuthEndpoint,proto3" json:"DeviceAuthEndpoint,omitempty"` // TokenEndpoint is an endpoint to request auth token. TokenEndpoint string `protobuf:"bytes,6,opt,name=TokenEndpoint,proto3" json:"TokenEndpoint,omitempty"` + // Scopes provides the scopes to be included in the token request + Scope string `protobuf:"bytes,7,opt,name=Scope,proto3" json:"Scope,omitempty"` + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool `protobuf:"varint,8,opt,name=UseIDToken,proto3" json:"UseIDToken,omitempty"` } func (x *ProviderConfig) Reset() { @@ -1407,6 +1411,20 @@ func (x *ProviderConfig) GetTokenEndpoint() string { return "" } +func (x *ProviderConfig) GetScope() string { + if x != nil { + return x.Scope + } + return "" +} + +func (x *ProviderConfig) GetUseIDToken() bool { + if x != nil { + return x.UseIDToken + } + return false +} + // Route represents a route.Route object type Route struct { state protoimpl.MessageState @@ -2000,7 +2018,7 @@ var file_management_proto_rawDesc = []byte{ 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, - 0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x10, 0x00, 0x22, 0x90, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, @@ -2013,81 +2031,84 @@ var file_management_proto_rawDesc = []byte{ 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, - 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, - 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, - 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, - 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, - 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0x7f, - 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, - 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, - 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, - 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22, - 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, - 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, - 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, - 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, + 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, + 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, + 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, + 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, + 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, + 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, + 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, + 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, + 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, + 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, + 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, + 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, + 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, + 0x44, 0x61, 0x74, 0x61, 0x22, 0x7f, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, + 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, - 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, - 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, + 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, + 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, + 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2132,7 +2153,7 @@ var file_management_proto_goTypes = []interface{}{ (*SimpleRecord)(nil), // 24: management.SimpleRecord (*NameServerGroup)(nil), // 25: management.NameServerGroup (*NameServer)(nil), // 26: management.NameServer - (*timestamp.Timestamp)(nil), // 27: google.protobuf.Timestamp + (*timestamppb.Timestamp)(nil), // 27: google.protobuf.Timestamp } var file_management_proto_depIdxs = []int32{ 11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig diff --git a/management/proto/management.proto b/management/proto/management.proto index 2c3c18c97..5447a9ee6 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -246,6 +246,10 @@ message ProviderConfig { string DeviceAuthEndpoint = 5; // TokenEndpoint is an endpoint to request auth token. string TokenEndpoint = 6; + // Scopes provides the scopes to be included in the token request + string Scope = 7; + // UseIDToken indicates if the id token should be used for authentication + bool UseIDToken = 8; } // Route represents a route.Route object diff --git a/management/server/account.go b/management/server/account.go index 01cae2e64..78c9237b8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/sha256" + b64 "encoding/base64" "fmt" "hash/crc32" "math/rand" @@ -54,7 +55,8 @@ type AccountManager interface { GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) - GetAccountFromPAT(pat string) (*Account, *User, error) + GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) + MarkPATUsed(tokenID string) error IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) @@ -66,8 +68,10 @@ type AccountManager interface { GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) - AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error - DeletePAT(accountID string, userID string, tokenID string) error + CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error + GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) @@ -1119,45 +1123,85 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e return nil } +// MarkPATUsed marks a personal access token as used +func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { + unlock := am.Store.AcquireGlobalLock() + + user, err := am.Store.GetUserByTokenID(tokenID) + if err != nil { + return err + } + + account, err := am.Store.GetAccountByUser(user.Id) + if err != nil { + return err + } + + unlock() + unlock = am.Store.AcquireAccountLock(account.Id) + defer unlock() + + account, err = am.Store.GetAccountByUser(user.Id) + if err != nil { + return err + } + + pat, ok := account.Users[user.Id].PATs[tokenID] + if !ok { + return fmt.Errorf("token not found") + } + + pat.LastUsed = time.Now() + + return am.Store.SaveAccount(account) +} + // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, error) { +func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { - return nil, nil, fmt.Errorf("token has wrong length") + return nil, nil, nil, fmt.Errorf("token has wrong length") } prefix := token[:len(PATPrefix)] if prefix != PATPrefix { - return nil, nil, fmt.Errorf("token has wrong prefix") + return nil, nil, nil, fmt.Errorf("token has wrong prefix") } secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { - return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) + return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) } secretChecksum := crc32.ChecksumIEEE([]byte(secret)) if secretChecksum != verificationChecksum { - return nil, nil, fmt.Errorf("token checksum does not match") + return nil, nil, nil, fmt.Errorf("token checksum does not match") } hashedToken := sha256.Sum256([]byte(token)) - tokenID, err := am.Store.GetTokenIDByHashedToken(string(hashedToken[:])) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + tokenID, err := am.Store.GetTokenIDByHashedToken(encodedHashedToken) if err != nil { - return nil, nil, err + return nil, nil, nil, err } user, err := am.Store.GetUserByTokenID(tokenID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } account, err := am.Store.GetAccountByUser(user.Id) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return account, user, nil + + pat := user.PATs[tokenID] + if pat == nil { + return nil, nil, nil, fmt.Errorf("personal access token not found") + } + + return account, user, pat, nil } // GetAccountFromToken returns an account associated with this token diff --git a/management/server/account_test.go b/management/server/account_test.go index af894817b..f21c93f0e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "fmt" "net" "reflect" @@ -465,12 +466,13 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) account.Users["someUser"] = &User{ Id: "someUser", PATs: map[string]*PersonalAccessToken{ - "pat1": { + "tokenId": { ID: "tokenId", - HashedToken: string(hashedToken[:]), + HashedToken: encodedHashedToken, }, }, } @@ -483,13 +485,52 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { Store: store, } - account, user, err := am.GetAccountFromPAT(token) + account, user, pat, err := am.GetAccountFromPAT(token) if err != nil { t.Fatalf("Error when getting Account from PAT: %s", err) } assert.Equal(t, "account_id", account.Id) assert.Equal(t, "someUser", user.Id) + assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) +} + +func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { + store := newStore(t) + account := newAccountWithId("account_id", "testuser", "") + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + account.Users["someUser"] = &User{ + Id: "someUser", + PATs: map[string]*PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + HashedToken: encodedHashedToken, + LastUsed: time.Time{}, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + err = am.MarkPATUsed("tokenId") + if err != nil { + t.Fatalf("Error when marking PAT used: %s", err) + } + + account, err = am.Store.GetAccount("account_id") + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) } func TestAccountManager_PrivateAccount(t *testing.T) { @@ -1245,7 +1286,7 @@ func TestAccount_Copy(t *testing.T) { PATs: map[string]*PersonalAccessToken{ "pat1": { ID: "pat1", - Description: "First PAT", + Name: "First PAT", HashedToken: "SoMeHaShEdToKeN", ExpirationDate: time.Now().AddDate(0, 0, 7), CreatedBy: "user1", diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index a4a46439d..dacac4129 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -83,6 +83,10 @@ const ( AccountPeerLoginExpirationDisabled // AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account AccountPeerLoginExpirationDurationUpdated + // PersonalAccessTokenCreated indicates that a user created a personal access token + PersonalAccessTokenCreated + // PersonalAccessTokenDeleted indicates that a user deleted a personal access token + PersonalAccessTokenDeleted ) const ( @@ -168,6 +172,10 @@ const ( AccountPeerLoginExpirationDisabledMessage string = "Peer login expiration disabled for the account" // AccountPeerLoginExpirationDurationUpdatedMessage is a human-readable text message of the AccountPeerLoginExpirationDurationUpdated activity AccountPeerLoginExpirationDurationUpdatedMessage string = "Peer login expiration duration updated" + // PersonalAccessTokenCreatedMessage is a human-readable text message of the PersonalAccessTokenCreated activity + PersonalAccessTokenCreatedMessage string = "Personal access token created" + // PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity + PersonalAccessTokenDeletedMessage string = "Personal access token deleted" ) // Activity that triggered an Event @@ -258,6 +266,10 @@ func (a Activity) Message() string { return AccountPeerLoginExpirationDisabledMessage case AccountPeerLoginExpirationDurationUpdated: return AccountPeerLoginExpirationDurationUpdatedMessage + case PersonalAccessTokenCreated: + return PersonalAccessTokenCreatedMessage + case PersonalAccessTokenDeleted: + return PersonalAccessTokenDeletedMessage default: return "UNKNOWN_ACTIVITY" } @@ -348,6 +360,10 @@ func (a Activity) StringCode() string { return "account.setting.peer.login.expiration.enable" case AccountPeerLoginExpirationDisabled: return "account.setting.peer.login.expiration.disable" + case PersonalAccessTokenCreated: + return "personal.access.token.create" + case PersonalAccessTokenDeleted: + return "personal.access.token.delete" default: return "UNKNOWN_ACTIVITY" } diff --git a/management/server/config.go b/management/server/config.go index 6a428c83b..9ec16b3e8 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -24,6 +24,11 @@ const ( NONE Provider = "none" ) +const ( + // DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow + DefaultDeviceAuthFlowScope string = "openid" +) + // Config of the Management service type Config struct { Stuns []*Host @@ -39,6 +44,17 @@ type Config struct { DeviceAuthorizationFlow *DeviceAuthorizationFlow } +// GetAuthAudiences returns the audience from the http config and device authorization flow config +func (c Config) GetAuthAudiences() []string { + audiences := []string{c.HttpConfig.AuthAudience} + + if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" { + audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience) + } + + return audiences +} + // TURNConfig is a config of the TURNCredentialsManager type TURNConfig struct { TimeBasedCredentials bool @@ -98,6 +114,10 @@ type ProviderConfig struct { TokenEndpoint string // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code DeviceAuthEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool } // validateURL validates input http url diff --git a/management/server/file_store.go b/management/server/file_store.go index 4f8092cfb..f79179841 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -112,7 +112,7 @@ func restore(file string) (*FileStore, error) { store.UserID2AccountID[user.Id] = accountID for _, pat := range user.PATs { store.TokenID2UserID[pat.ID] = user.Id - store.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + store.HashedPAT2TokenID[pat.HashedToken] = pat.ID } } @@ -121,15 +121,25 @@ func restore(file string) (*FileStore, error) { store.PrivateDomain2AccountID[account.Domain] = accountID } - // if no policies are defined, that means we need to migrate Rules to policies - if len(account.Policies) == 0 { + // TODO: policy query generated from the Go template and rule object. + // We need to refactor this part to avoid using templating for policies queries building + // and drop this migration part. + policies := make(map[string]int, len(account.Policies)) + for i, policy := range account.Policies { + policies[policy.ID] = i + } + if account.Policies == nil { account.Policies = make([]*Policy, 0) - for _, rule := range account.Rules { - policy, err := RuleToPolicy(rule) - if err != nil { - log.Errorf("unable to migrate rule to policy: %v", err) - continue - } + } + for _, rule := range account.Rules { + policy, err := RuleToPolicy(rule) + if err != nil { + log.Errorf("unable to migrate rule to policy: %v", err) + continue + } + if i, ok := policies[policy.ID]; ok { + account.Policies[i] = policy + } else { account.Policies = append(account.Policies, policy) } } @@ -268,7 +278,7 @@ func (s *FileStore) SaveAccount(account *Account) error { s.UserID2AccountID[user.Id] = accountCopy.Id for _, pat := range user.PATs { s.TokenID2UserID[pat.ID] = user.Id - s.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + s.HashedPAT2TokenID[pat.HashedToken] = pat.ID } } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 45be9815c..f63a55d65 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -6,11 +6,10 @@ import ( "strings" "time" - pb "github.com/golang/protobuf/proto" //nolint + pb "github.com/golang/protobuf/proto" // nolint "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/golang/protobuf/ptypes/timestamp" @@ -33,7 +32,7 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config turnCredentialsManager TURNCredentialsManager - jwtMiddleware *middleware.JWTMiddleware + jwtValidator *jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics } @@ -47,12 +46,12 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager return nil, err } - var jwtMiddleware *middleware.JWTMiddleware + var jwtValidator *jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { - jwtMiddleware, err = middleware.NewJwtMiddleware( + jwtValidator, err = jwtclaims.NewJWTValidator( config.HttpConfig.AuthIssuer, - config.HttpConfig.AuthAudience, + config.GetAuthAudiences(), config.HttpConfig.AuthKeysLocation) if err != nil { return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) @@ -88,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager accountManager: accountManager, config: config, turnCredentialsManager: turnCredentialsManager, - jwtMiddleware: jwtMiddleware, + jwtValidator: jwtValidator, jwtClaimsExtractor: jwtClaimsExtractor, appMetrics: appMetrics, }, nil @@ -189,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { } func (s *GRPCServer) validateToken(jwtToken string) (string, error) { - if s.jwtMiddleware == nil { - return "", status.Error(codes.Internal, "no jwt middleware set") + if s.jwtValidator == nil { + return "", status.Error(codes.Internal, "no jwt validator set") } - token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) + token, err := s.jwtValidator.ValidateAndParse(jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } @@ -511,6 +510,8 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience, DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint, TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint, + Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope, + UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken, }, } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index b3d954a4d..eaeb5693c 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -6,6 +6,8 @@ info: tags: - name: Users description: Interact with and view information about users. + - name: Tokens + description: Interact with and view information about tokens. - name: Peers description: Interact with and view information about peers. - name: Setup Keys @@ -284,6 +286,59 @@ components: - revoked - auto_groups - usage_limit + PersonalAccessToken: + type: object + properties: + id: + description: ID of a token + type: string + name: + description: Name of the token + type: string + expiration_date: + description: Date the token expires + type: string + format: date-time + created_by: + description: User ID of the user who created the token + type: string + created_at: + description: Date the token was created + type: string + format: date-time + last_used: + description: Date the token was last used + type: string + format: date-time + required: + - id + - name + - expiration_date + - created_by + - created_at + PersonalAccessTokenGenerated: + type: object + properties: + plain_token: + description: Plain text representation of the generated token + type: string + personal_access_token: + $ref: '#/components/schemas/PersonalAccessToken' + required: + - plain_token + - personal_access_token + PersonalAccessTokenRequest: + type: object + properties: + name: + description: Name of the token + type: string + expires_in: + description: Expiration in days + type: integer + required: + - name + - expires_in GroupMinimum: type: object properties: @@ -848,6 +903,133 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/{userId}/tokens: + get: + summary: Returns a list of all tokens for a user + tags: [ Tokens ] + security: + - BearerAuth: [] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + responses: + '200': + description: A JSON Array of PersonalAccessTokens + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/PersonalAccessToken' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a new token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + requestBody: + description: PersonalAccessToken create parameters + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessTokenRequest' + responses: + '200': + description: The token in plain text + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessTokenGenerated' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/{userId}/tokens/{tokenId}: + get: + summary: Returns a specific token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + - in: path + name: tokenId + required: true + schema: + type: string + description: The Token ID + responses: + '200': + description: A PersonalAccessTokens Object + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessToken' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + - in: path + name: tokenId + required: true + schema: + type: string + description: The Token ID + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers: get: summary: Returns a list of all peers diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 372ecd1a7..930a9df54 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -379,6 +379,44 @@ type PeerMinimum struct { Name string `json:"name"` } +// PersonalAccessToken defines model for PersonalAccessToken. +type PersonalAccessToken struct { + // CreatedAt Date the token was created + CreatedAt time.Time `json:"created_at"` + + // CreatedBy User ID of the user who created the token + CreatedBy string `json:"created_by"` + + // ExpirationDate Date the token expires + ExpirationDate time.Time `json:"expiration_date"` + + // Id ID of a token + Id string `json:"id"` + + // LastUsed Date the token was last used + LastUsed *time.Time `json:"last_used,omitempty"` + + // Name Name of the token + Name string `json:"name"` +} + +// PersonalAccessTokenGenerated defines model for PersonalAccessTokenGenerated. +type PersonalAccessTokenGenerated struct { + PersonalAccessToken PersonalAccessToken `json:"personal_access_token"` + + // PlainToken Plain text representation of the generated token + PlainToken string `json:"plain_token"` +} + +// PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest. +type PersonalAccessTokenRequest struct { + // ExpiresIn Expiration in days + ExpiresIn int `json:"expires_in"` + + // Name Name of the token + Name string `json:"name"` +} + // Policy defines model for Policy. type Policy struct { // Description Policy friendly description @@ -808,3 +846,6 @@ type PostApiUsersJSONRequestBody = UserCreateRequest // PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType. type PutApiUsersIdJSONRequestBody = UserRequest + +// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType. +type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 9712d2e75..2464f47ef 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -300,7 +300,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetGroup returns a group diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 90f62e700..d8117a436 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -8,6 +8,7 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -25,15 +26,17 @@ type apiHandler struct { AuthCfg AuthCfg } +// EmptyObject is an empty struct used to return empty JSON object +type emptyObject struct { +} + // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { - jwtMiddleware, err := middleware.NewJwtMiddleware( - authCfg.Issuer, - authCfg.Audience, - authCfg.KeysLocation) - if err != nil { - return nil, err - } +func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { + authMiddleware := middleware.NewAuthMiddleware( + accountManager.GetAccountFromPAT, + jwtValidator.ValidateAndParse, + accountManager.MarkPATUsed, + authCfg.Audience) corsMiddleware := cors.AllowAll() @@ -46,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics metricsMiddleware := appMetrics.HTTPMiddleware() router := rootRouter.PathPrefix("/api").Subrouter() - router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) + router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) api := apiHandler{ Router: router, @@ -57,6 +60,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics api.addAccountsEndpoint() api.addPeersEndpoint() api.addUsersEndpoint() + api.addUsersTokensEndpoint() api.addSetupKeysEndpoint() api.addRulesEndpoint() api.addPoliciesEndpoint() @@ -66,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics api.addDNSSettingEndpoint() api.addEventsEndpoint() - err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { + err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { methods, err := route.GetMethods() if err != nil { return err @@ -110,6 +114,14 @@ func (apiHandler *apiHandler) addUsersEndpoint() { apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") } +func (apiHandler *apiHandler) addUsersTokensEndpoint() { + tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) + apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") +} + func (apiHandler *apiHandler) addSetupKeysEndpoint() { keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 5e56f75ab..5f8389dfa 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -2,6 +2,9 @@ package middleware import ( "net/http" + "regexp" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" @@ -39,10 +42,22 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return } - if !ok { switch r.Method { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: + + ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) + if err != nil { + log.Debugf("Regex failed") + util.WriteError(status.Errorf(status.Internal, ""), w) + return + } + if ok { + log.Debugf("Valid Path") + h.ServeHTTP(w, r) + return + } + util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w) return } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go new file mode 100644 index 000000000..a8c81012a --- /dev/null +++ b/management/server/http/middleware/auth_middleware.go @@ -0,0 +1,173 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" +) + +// GetAccountFromPATFunc function +type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + +// ValidateAndParseTokenFunc function +type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) + +// MarkPATUsedFunc function +type MarkPATUsedFunc func(token string) error + +// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens +type AuthMiddleware struct { + getAccountFromPAT GetAccountFromPATFunc + validateAndParseToken ValidateAndParseTokenFunc + markPATUsed MarkPATUsedFunc + audience string +} + +const ( + userProperty = "user" +) + +// NewAuthMiddleware instance constructor +func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware { + return &AuthMiddleware{ + getAccountFromPAT: getAccountFromPAT, + validateAndParseToken: validateAndParseToken, + markPATUsed: markPATUsed, + audience: audience, + } +} + +// Handler method of the middleware which authenticates a user either by JWT claims or by PAT +func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := strings.Split(r.Header.Get("Authorization"), " ") + authType := auth[0] + switch strings.ToLower(authType) { + case "bearer": + err := m.CheckJWTFromRequest(w, r) + if err != nil { + log.Debugf("Error when validating JWT claims: %s", err.Error()) + util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) + return + } + h.ServeHTTP(w, r) + case "token": + err := m.CheckPATFromRequest(w, r) + if err != nil { + log.Debugf("Error when validating PAT claims: %s", err.Error()) + util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) + return + } + h.ServeHTTP(w, r) + default: + util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w) + return + } + }) +} + +// CheckJWTFromRequest checks if the JWT is valid +func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { + + token, err := getTokenFromJWTRequest(r) + + // If an error occurs, call the error handler and return an error + if err != nil { + return fmt.Errorf("Error extracting token: %w", err) + } + + validatedToken, err := m.validateAndParseToken(token) + if err != nil { + return err + } + + if validatedToken == nil { + return nil + } + + // If we get here, everything worked and we can set the + // user property in context. + newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint + // Update the current request with the new context information. + *r = *newRequest + return nil +} + +// CheckPATFromRequest checks if the PAT is valid +func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error { + token, err := getTokenFromPATRequest(r) + + // If an error occurs, call the error handler and return an error + if err != nil { + return fmt.Errorf("Error extracting token: %w", err) + } + + account, user, pat, err := m.getAccountFromPAT(token) + if err != nil { + return fmt.Errorf("invalid Token: %w", err) + } + if time.Now().After(pat.ExpirationDate) { + return fmt.Errorf("token expired") + } + + err = m.markPATUsed(pat.ID) + if err != nil { + return err + } + + claimMaps := jwt.MapClaims{} + claimMaps[jwtclaims.UserIDClaim] = user.Id + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint + // Update the current request with the new context information. + *r = *newRequest + return nil +} + +// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts +// the JWT token from the Authorization header. +func getTokenFromJWTRequest(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil // No error, just no token + } + + // TODO: Make this a bit more robust, parsing-wise + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts +// the PAT token from the Authorization header. +func getTokenFromPATRequest(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil // No error, just no token + } + + // TODO: Make this a bit more robust, parsing-wise + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" { + return "", errors.New("Authorization header format must be Token {token}") + } + + return authHeaderParts[1], nil +} diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go new file mode 100644 index 000000000..5a5558fa5 --- /dev/null +++ b/management/server/http/middleware/auth_middleware_test.go @@ -0,0 +1,123 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/management/server" +) + +const ( + audience = "audience" + accountID = "accountID" + domain = "domain" + userID = "userID" + tokenID = "tokenID" + PAT = "PAT" + JWT = "JWT" + wrongToken = "wrongToken" +) + +var testAccount = &server.Account{ + Id: accountID, + Domain: domain, + Users: map[string]*server.User{ + userID: { + Id: userID, + PATs: map[string]*server.PersonalAccessToken{ + tokenID: { + ID: tokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: userID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + +func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { + if token == PAT { + return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil + } + return nil, nil, nil, fmt.Errorf("PAT invalid") +} + +func mockValidateAndParseToken(token string) (*jwt.Token, error) { + if token == JWT { + return &jwt.Token{}, nil + } + return nil, fmt.Errorf("JWT invalid") +} + +func mockMarkPATUsed(token string) error { + if token == tokenID { + return nil + } + return fmt.Errorf("Should never get reached") +} + +func TestAuthMiddleware_Handler(t *testing.T) { + tt := []struct { + name string + authHeader string + expectedStatusCode int + }{ + { + name: "Valid PAT Token", + authHeader: "Token " + PAT, + expectedStatusCode: 200, + }, + { + name: "Invalid PAT Token", + authHeader: "Token " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Valid JWT Token", + authHeader: "Bearer " + JWT, + expectedStatusCode: 200, + }, + { + name: "Invalid JWT Token", + authHeader: "Bearer " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Basic Auth", + authHeader: "Basic " + PAT, + expectedStatusCode: 401, + }, + } + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // do nothing + }) + + authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience) + + handlerToTest := authMiddleware.Handler(nextHandler) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set("Authorization", tc.authHeader) + rec := httptest.NewRecorder() + + handlerToTest.ServeHTTP(rec, req) + + if rec.Result().StatusCode != tc.expectedStatusCode { + t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode) + } + }) + } + +} diff --git a/management/server/http/middleware/jwt.go b/management/server/http/middleware/jwt.go deleted file mode 100644 index 1ac6d3948..000000000 --- a/management/server/http/middleware/jwt.go +++ /dev/null @@ -1,254 +0,0 @@ -package middleware - -import ( - "context" - "errors" - "fmt" - "net/http" - "strings" - - "github.com/golang-jwt/jwt" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" -) - -// A function called whenever an error is encountered -type errorHandler func(w http.ResponseWriter, r *http.Request, err string) - -// TokenExtractor is a function that takes a request as input and returns -// either a token or an error. An error should only be returned if an attempt -// to specify a token was found, but the information was somehow incorrectly -// formed. In the case where a token is simply not present, this should not -// be treated as an error. An empty string should be returned in that case. -type TokenExtractor func(r *http.Request) (string, error) - -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - ErrorHandler errorHandler - // A boolean indicating if the credentials are required or not - // Default value: false - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Extractor TokenExtractor - // Debug flag turns on debugging output - // Default: false - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod -} - -type JWTMiddleware struct { - Options Options -} - -func OnError(w http.ResponseWriter, r *http.Request, err string) { - util.WriteError(status.Errorf(status.Unauthorized, ""), w) -} - -// New constructs a new Secure instance with supplied options. -func New(options ...Options) *JWTMiddleware { - - var opts Options - if len(options) == 0 { - opts = Options{} - } else { - opts = options[0] - } - - if opts.UserProperty == "" { - opts.UserProperty = "user" - } - - if opts.ErrorHandler == nil { - opts.ErrorHandler = OnError - } - - if opts.Extractor == nil { - opts.Extractor = FromAuthHeader - } - - return &JWTMiddleware{ - Options: opts, - } -} - -func (m *JWTMiddleware) logf(format string, args ...interface{}) { - if m.Options.Debug { - log.Printf(format, args...) - } -} - -// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere. -func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - err := m.CheckJWTFromRequest(w, r) - - // If there was an error, do not call next. - if err == nil && next != nil { - next(w, r) - } -} - -func (m *JWTMiddleware) Handler(h http.Handler) http.Handler { - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Let secure process the request. If it returns an error, - // that indicates the request should not continue. - err := m.CheckJWTFromRequest(w, r) - - // If there was an error, do not continue. - if err != nil { - log.Errorf("received an error while validating the JWT token: %s. "+ - "Review your IDP configuration and ensure that "+ - "settings are in sync between dashboard and management", err) - return - } - - h.ServeHTTP(w, r) - }) -} - -// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts -// the JWT token from the Authorization header. -func FromAuthHeader(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") - } - - return authHeaderParts[1], nil -} - -// FromParameter returns a function that extracts the token from the specified -// query string parameter -func FromParameter(param string) TokenExtractor { - return func(r *http.Request) (string, error) { - return r.URL.Query().Get(param), nil - } -} - -// FromFirst returns a function that runs multiple token extractors and takes the -// first token it finds -func FromFirst(extractors ...TokenExtractor) TokenExtractor { - return func(r *http.Request) (string, error) { - for _, ex := range extractors { - token, err := ex(r) - if err != nil { - return "", err - } - if token != "" { - return token, nil - } - } - return "", nil - } -} - -func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { - if !m.Options.EnableAuthOnOptions { - if r.Method == "OPTIONS" { - return nil - } - } - - // Use the specified token extractor to extract a token from the request - token, err := m.Options.Extractor(r) - - // If debugging is turned on, log the outcome - if err != nil { - m.logf("Error extracting JWT: %v", err) - } else { - m.logf("Token extracted: %s", token) - } - - // If an error occurs, call the error handler and return an error - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return fmt.Errorf("Error extracting token: %w", err) - } - - validatedToken, err := m.ValidateAndParse(token) - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return err - } - - if validatedToken == nil { - return nil - } - - // If we get here, everything worked and we can set the - // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint - // Update the current request with the new context information. - *r = *newRequest - return nil -} - -// ValidateAndParse validates and parses a given access token against jwt standards and signing methods -func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) { - // If the token is empty... - if token == "" { - // Check if it was required - if m.Options.CredentialsOptional { - m.logf("no credentials found (CredentialsOptional=true)") - // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil, nil - } - - // If we get here, the required token is missing - errorMsg := "required authorization token not found" - m.logf(" Error: No credentials found (CredentialsOptional=false)") - return nil, fmt.Errorf(errorMsg) - } - - // Now parse the token - parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter) - - // Check if there was an error in parsing... - if err != nil { - m.logf("error parsing token: %v", err) - return nil, fmt.Errorf("Error parsing token: %w", err) - } - - if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] { - errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", - m.Options.SigningMethod.Alg(), - parsedToken.Header["alg"]) - m.logf("error validating token algorithm: %s", errorMsg) - return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) - } - - // Check if the parsed token is valid... - if !parsedToken.Valid { - errorMsg := "token is invalid" - m.logf(errorMsg) - return nil, errors.New(errorMsg) - } - - return parsedToken, nil -} diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index e7be617e2..5ad52a426 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -243,7 +243,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetNameserverGroup handles a nameserver group Get request identified by ID diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go new file mode 100644 index 000000000..d2398a7e1 --- /dev/null +++ b/management/server/http/pat_handler.go @@ -0,0 +1,178 @@ +package http + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" +) + +// PATHandler is the nameserver group handler of the account +type PATHandler struct { + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor +} + +// NewPATsHandler creates a new PATHandler HTTP handler +func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { + return &PATHandler{ + accountManager: accountManager, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user +func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + userID := vars["userId"] + if len(userID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID) + if err != nil { + util.WriteError(err, w) + return + } + + var patResponse []*api.PersonalAccessToken + for _, pat := range pats { + patResponse = append(patResponse, toPATResponse(pat)) + } + + util.WriteJSONObject(w, patResponse) +} + +// GetToken is HTTP GET handler that returns a personal access token for the given user +func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + tokenID := vars["tokenId"] + if len(tokenID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + return + } + + pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, toPATResponse(pat)) +} + +// CreateToken is HTTP POST handler that creates a personal access token for the given user +func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + var req api.PostApiUsersUserIdTokensJSONRequestBody + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, toPATGeneratedResponse(pat)) +} + +// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user +func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + tokenID := vars["tokenId"] + if len(tokenID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + return + } + + err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, emptyObject{}) +} + +func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { + var lastUsed *time.Time + if !pat.LastUsed.IsZero() { + lastUsed = &pat.LastUsed + } + return &api.PersonalAccessToken{ + CreatedAt: pat.CreatedAt, + CreatedBy: pat.CreatedBy, + Name: pat.Name, + ExpirationDate: pat.ExpirationDate, + Id: pat.ID, + LastUsed: lastUsed, + } +} + +func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { + return &api.PersonalAccessTokenGenerated{ + PlainToken: pat.PlainToken, + PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), + } +} diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go new file mode 100644 index 000000000..de79f1006 --- /dev/null +++ b/management/server/http/pat_handler_test.go @@ -0,0 +1,254 @@ +package http + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/status" +) + +const ( + existingAccountID = "existingAccountID" + notFoundAccountID = "notFoundAccountID" + existingUserID = "existingUserID" + notFoundUserID = "notFoundUserID" + existingTokenID = "existingTokenID" + notFoundTokenID = "notFoundTokenID" + domain = "hotmail.com" +) + +var testAccount = &server.Account{ + Id: existingAccountID, + Domain: domain, + Users: map[string]*server.User{ + existingUserID: { + Id: existingUserID, + PATs: map[string]*server.PersonalAccessToken{ + existingTokenID: { + ID: existingTokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + "token2": { + ID: "token2", + Name: "My second token", + HashedToken: "someOtherHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + +func initPATTestData() *PATHandler { + return &PATHandler{ + accountManager: &mock_server.MockAccountManager{ + CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + return &server.PersonalAccessTokenGenerated{ + PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe", + PersonalAccessToken: server.PersonalAccessToken{}, + }, nil + }, + + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + return testAccount, testAccount.Users[existingUserID], nil + }, + DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error { + if accountID != existingAccountID { + return status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + if tokenID != existingTokenID { + return status.Errorf(status.NotFound, "token with ID %s not found", tokenID) + } + return nil + }, + GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + if tokenID != existingTokenID { + return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID) + } + return testAccount.Users[existingUserID].PATs[existingTokenID], nil + }, + GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil + }, + }, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ + UserId: existingUserID, + Domain: domain, + AccountId: testNSGroupAccountID, + } + }), + ), + } +} + +func TestTokenHandlers(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "Get All Tokens", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens", + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Not Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "Delete Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + }, + { + name: "Delete Not Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "POST OK", + requestType: http.MethodPost, + requestPath: "/api/users/" + existingUserID + "/tokens", + requestBody: bytes.NewBuffer( + []byte("{\"name\":\"name\",\"expires_in\":7}")), + expectedStatus: http.StatusOK, + expectedBody: true, + }, + } + + p := initPATTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v, content: %s", + status, tc.expectedStatus, string(content)) + return + } + + if !tc.expectedBody { + return + } + + switch tc.name { + case "POST OK": + got := &api.PersonalAccessTokenGenerated{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.NotEmpty(t, got.PlainToken) + assert.Equal(t, server.PATLength, len(got.PlainToken)) + case "Get All Tokens": + expectedTokens := []api.PersonalAccessToken{ + toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), + toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]), + } + + var got []api.PersonalAccessToken + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.True(t, cmp.Equal(got, expectedTokens)) + case "Get Existing Token": + expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]) + got := &api.PersonalAccessToken{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.True(t, cmp.Equal(*got, expectedToken)) + } + + }) + } +} + +func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { + return api.PersonalAccessToken{ + Id: serverToken.ID, + Name: serverToken.Name, + CreatedAt: serverToken.CreatedAt, + LastUsed: &serverToken.LastUsed, + CreatedBy: serverToken.CreatedBy, + ExpirationDate: serverToken.ExpirationDate, + } +} diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 76c4f7502..7379277af 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -66,7 +66,7 @@ func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w htt util.WriteError(err, w) return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations diff --git a/management/server/http/policies.go b/management/server/http/policies.go index 275992d14..a0fe3b1e2 100644 --- a/management/server/http/policies.go +++ b/management/server/http/policies.go @@ -225,7 +225,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetPolicy handles a group Get request identified by ID diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index b29e5c261..aaaaaa854 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -321,7 +321,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetRoute handles a route Get request identified by ID diff --git a/management/server/http/rules_handler.go b/management/server/http/rules_handler.go index 8925c3763..f8bb5f0cb 100644 --- a/management/server/http/rules_handler.go +++ b/management/server/http/rules_handler.go @@ -222,7 +222,7 @@ func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetRule handles a group Get request identified by ID diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 0055511a2..407443251 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -4,10 +4,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" "net/http" + "strings" "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/status" ) // WriteJSONObject simply writes object to the HTTP reponse in JSON format @@ -93,9 +96,11 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusInternalServerError case status.InvalidArgument: httpStatus = http.StatusUnprocessableEntity + case status.Unauthorized: + httpStatus = http.StatusUnauthorized default: } - msg = err.Error() + msg = strings.ToLower(err.Error()) } else { unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) log.Error(unhandledMSG) diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 9d60da335..9aa00a004 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -7,14 +7,19 @@ import ( ) const ( - TokenUserProperty = "user" - AccountIDSuffix = "wt_account_id" - DomainIDSuffix = "wt_account_domain" + // TokenUserProperty key for the user property in the request context + TokenUserProperty = "user" + // AccountIDSuffix suffix for the account id claim + AccountIDSuffix = "wt_account_id" + // DomainIDSuffix suffix for the domain id claim + DomainIDSuffix = "wt_account_domain" + // DomainCategorySuffix suffix for the domain category claim DomainCategorySuffix = "wt_account_domain_category" - UserIDClaim = "sub" + // UserIDClaim claim for the user id + UserIDClaim = "sub" ) -// Extract function type +// ExtractClaims Extract function type type ExtractClaims func(r *http.Request) AuthorizationClaims // ClaimsExtractor struct that holds the extract function diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index d8acd79b6..53f8818b1 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -26,7 +26,7 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) require.NoError(t, err, "creating testing request failed") - testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint + testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint return testRequest } diff --git a/management/server/http/middleware/handler.go b/management/server/jwtclaims/jwtValidator.go similarity index 52% rename from management/server/http/middleware/handler.go rename to management/server/jwtclaims/jwtValidator.go index c647506bc..147f8f2eb 100644 --- a/management/server/http/middleware/handler.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -1,4 +1,4 @@ -package middleware +package jwtclaims import ( "bytes" @@ -17,6 +17,32 @@ import ( log "github.com/sirupsen/logrus" ) +// Options is a struct for specifying configuration options for the middleware. +type Options struct { + // The function that will return the Key to validate the JWT. + // It can be either a shared secret or a public key. + // Default value: nil + ValidationKeyGetter jwt.Keyfunc + // The name of the property in the request where the user information + // from the JWT will be stored. + // Default value: "user" + UserProperty string + // The function that will be called when there's an error validating the token + // Default value: + CredentialsOptional bool + // A function that extracts the token from the request + // Default: FromAuthHeader (i.e., from Authorization header as bearer token) + Debug bool + // When set, all requests with the OPTIONS method will use authentication + // Default: false + EnableAuthOnOptions bool + // When set, the middelware verifies that tokens are signed with the specific signing algorithm + // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks + // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ + // Default: nil + SigningMethod jwt.SigningMethod +} + // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { Keys []JSONWebKey `json:"keys"` @@ -32,17 +58,28 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } -// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header -func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { +// JWTValidator struct to handle token validation and parsing +type JWTValidator struct { + options Options +} + +// NewJWTValidator constructor +func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { return nil, err } - return New(Options{ + options := Options{ ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { // Verify 'aud' claim - checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) + var checkAud bool + for _, audience := range audienceList { + checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) + if checkAud { + break + } + } if !checkAud { return token, errors.New("invalid audience") } @@ -62,7 +99,59 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT }, SigningMethod: jwt.SigningMethodRS256, EnableAuthOnOptions: false, - }), nil + } + + if options.UserProperty == "" { + options.UserProperty = "user" + } + + return &JWTValidator{ + options: options, + }, nil +} + +// ValidateAndParse validates the token and returns the parsed token +func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { + // If the token is empty... + if token == "" { + // Check if it was required + if m.options.CredentialsOptional { + log.Debugf("no credentials found (CredentialsOptional=true)") + // No error, just no token (and that is ok given that CredentialsOptional is true) + return nil, nil + } + + // If we get here, the required token is missing + errorMsg := "required authorization token not found" + log.Debugf(" Error: No credentials found (CredentialsOptional=false)") + return nil, fmt.Errorf(errorMsg) + } + + // Now parse the token + parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter) + + // Check if there was an error in parsing... + if err != nil { + log.Debugf("error parsing token: %v", err) + return nil, fmt.Errorf("Error parsing token: %w", err) + } + + if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] { + errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", + m.options.SigningMethod.Alg(), + parsedToken.Header["alg"]) + log.Debugf("error validating token algorithm: %s", errorMsg) + return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) + } + + // Check if the parsed token is valid... + if !parsedToken.Valid { + errorMsg := "token is invalid" + log.Debugf(errorMsg) + return nil, errors.New(errorMsg) + } + + return parsedToken, nil } func getPemKeys(keysLocation string) (*Jwks, error) { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2ae71d1a9..eb473d03c 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,7 +47,8 @@ type MockAccountManager struct { DeletePolicyFunc func(accountID, policyID, userID string) error ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, error) + GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + MarkPATUsedFunc func(pat string) error UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) @@ -60,8 +61,10 @@ type MockAccountManager struct { SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - AddPATToUserFunc func(accountID string, userID string, pat *server.PersonalAccessToken) error - DeletePATFunc func(accountID string, userID string, tokenID string) error + CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error + GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error @@ -179,29 +182,53 @@ func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*ser } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, error) { +func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { return am.GetAccountFromPATFunc(pat) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } -// AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface -func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error { - if am.AddPATToUserFunc != nil { - return am.AddPATToUserFunc(accountID, userID, pat) +// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface +func (am *MockAccountManager) MarkPATUsed(pat string) error { + if am.MarkPATUsedFunc != nil { + return am.MarkPATUsedFunc(pat) } - return status.Errorf(codes.Unimplemented, "method AddPATToUser is not implemented") + return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented") +} + +// CreatePAT mock implementation of GetPAT from server.AccountManager interface +func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + if am.CreatePATFunc != nil { + return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn) + } + return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented") } // DeletePAT mock implementation of DeletePAT from server.AccountManager interface -func (am *MockAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { +func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { if am.DeletePATFunc != nil { - return am.DeletePATFunc(accountID, userID, tokenID) + return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID) } return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") } +// GetPAT mock implementation of GetPAT from server.AccountManager interface +func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + if am.GetPATFunc != nil { + return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented") +} + +// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface +func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + if am.GetAllPATsFunc != nil { + return am.GetAllPATsFunc(accountID, executingUserID, targetUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented") +} + // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) { if am.GetNetworkMapFunc != nil { diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 7416a9e0b..bdf34e9fd 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "fmt" "hash/crc32" "time" @@ -25,7 +26,7 @@ const ( // PersonalAccessToken holds all information about a PAT including a hashed version of it for verification type PersonalAccessToken struct { ID string - Description string + Name string HashedToken string ExpirationDate time.Time // scope could be added in future @@ -34,23 +35,33 @@ type PersonalAccessToken struct { LastUsed time.Time } +// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it +type PersonalAccessTokenGenerated struct { + PlainToken string + PersonalAccessToken +} + // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version -func CreateNewPAT(description string, expirationInDays int, createdBy string) (*PersonalAccessToken, string, error) { +func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { hashedToken, plainToken, err := generateNewToken() if err != nil { - return nil, "", err + return nil, err } currentTime := time.Now().UTC() - return &PersonalAccessToken{ - ID: xid.New().String(), - Description: description, - HashedToken: hashedToken, - ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), - CreatedBy: createdBy, - CreatedAt: currentTime, - LastUsed: currentTime, - }, plainToken, nil + return &PersonalAccessTokenGenerated{ + PersonalAccessToken: PersonalAccessToken{ + ID: xid.New().String(), + Name: name, + HashedToken: hashedToken, + ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + CreatedBy: createdBy, + CreatedAt: currentTime, + LastUsed: time.Time{}, + }, + PlainToken: plainToken, + }, nil + } func generateNewToken() (string, string, error) { @@ -64,5 +75,6 @@ func generateNewToken() (string, string, error) { paddedChecksum := fmt.Sprintf("%06s", encodedChecksum) plainToken := PATPrefix + secret + paddedChecksum hashedToken := sha256.Sum256([]byte(plainToken)) - return string(hashedToken[:]), plainToken, nil + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + return encodedHashedToken, plainToken, nil } diff --git a/management/server/personal_access_token_test.go b/management/server/personal_access_token_test.go index a4e02f750..03dd2ef4e 100644 --- a/management/server/personal_access_token_test.go +++ b/management/server/personal_access_token_test.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "hash/crc32" "strings" "testing" @@ -13,7 +14,8 @@ import ( func TestPAT_GenerateToken_Hashing(t *testing.T) { hashedToken, plainToken, _ := generateNewToken() expectedToken := sha256.Sum256([]byte(plainToken)) - assert.Equal(t, hashedToken, string(expectedToken[:])) + encodedExpectedToken := b64.StdEncoding.EncodeToString(expectedToken[:]) + assert.Equal(t, hashedToken, encodedExpectedToken) } func TestPAT_GenerateToken_Prefix(t *testing.T) { diff --git a/management/server/policy.go b/management/server/policy.go index 31f6bb655..8a166c25c 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -178,6 +178,9 @@ type FirewallRule struct { // Port of the traffic Port string + + // id for internal purposes + id string } // parseFromRegoResult parses the Rego result to a FirewallRule. @@ -218,39 +221,35 @@ func (f *FirewallRule) parseFromRegoResult(value interface{}) error { f.Action = action f.Port = port + // NOTE: update this id each time when new field added + f.id = peerID + peerIP + direction + action + port + return nil } -// getRegoQuery returns a initialized Rego object with default rule. -func (a *Account) getRegoQuery() (rego.PreparedEvalQuery, error) { - queries := []func(*rego.Rego){ - rego.Query("data.netbird.all"), - rego.Module("netbird", defaultPolicyModule), - } - for i, p := range a.Policies { - if !p.Enabled { - continue - } - queries = append(queries, rego.Module(fmt.Sprintf("netbird-%d", i), p.Query)) - } - return rego.New(queries...).PrepareForEval(context.TODO()) -} - -// getPeersByPolicy returns all peers that given peer has access to. -func (a *Account) getPeersByPolicy(peerID string) ([]*Peer, []*FirewallRule) { +// queryPeersAndFwRulesByRego returns a list associated Peers and firewall rules list for this peer. +func (a *Account) queryPeersAndFwRulesByRego( + peerID string, + queryNumber int, + query string, +) ([]*Peer, []*FirewallRule) { input := map[string]interface{}{ "peer_id": peerID, "peers": a.Peers, "groups": a.Groups, } - query, err := a.getRegoQuery() + stmt, err := rego.New( + rego.Query("data.netbird.all"), + rego.Module("netbird", defaultPolicyModule), + rego.Module(fmt.Sprintf("netbird-%d", queryNumber), query), + ).PrepareForEval(context.TODO()) if err != nil { log.WithError(err).Error("get Rego query") return nil, nil } - evalResult, err := query.Eval( + evalResult, err := stmt.Eval( context.TODO(), rego.EvalInput(input), ) @@ -318,6 +317,33 @@ func (a *Account) getPeersByPolicy(peerID string) ([]*Peer, []*FirewallRule) { return peers, rules } +// getPeersByPolicy returns all peers that given peer has access to. +func (a *Account) getPeersByPolicy(peerID string) (peers []*Peer, rules []*FirewallRule) { + peersSeen := make(map[string]struct{}) + ruleSeen := make(map[string]struct{}) + for i, policy := range a.Policies { + if !policy.Enabled { + continue + } + p, r := a.queryPeersAndFwRulesByRego(peerID, i, policy.Query) + for _, peer := range p { + if _, ok := peersSeen[peer.ID]; ok { + continue + } + peers = append(peers, peer) + peersSeen[peer.ID] = struct{}{} + } + for _, rule := range r { + if _, ok := ruleSeen[rule.id]; ok { + continue + } + rules = append(rules, rule) + ruleSeen[rule.id] = struct{}{} + } + } + return +} + // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) { unlock := am.Store.AcquireAccountLock(accountID) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 73663a8fd..39ac44843 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -5,63 +5,268 @@ import ( "testing" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) func TestAccount_getPeersByPolicy(t *testing.T) { account := &Account{ Peers: map[string]*Peer{ - "peer1": { - ID: "peer1", - IP: net.IPv4(10, 20, 0, 1), + "cfif97at2r9s73au3q00": { + ID: "cfif97at2r9s73au3q00", + IP: net.ParseIP("100.65.14.88"), }, - "peer2": { - ID: "peer2", - IP: net.IPv4(10, 20, 0, 2), + "cfif97at2r9s73au3q0g": { + ID: "cfif97at2r9s73au3q0g", + IP: net.ParseIP("100.65.80.39"), }, - "peer3": { - ID: "peer3", - IP: net.IPv4(10, 20, 0, 3), + "cfif97at2r9s73au3q10": { + ID: "cfif97at2r9s73au3q10", + IP: net.ParseIP("100.65.254.139"), + }, + "cfif97at2r9s73au3q20": { + ID: "cfif97at2r9s73au3q20", + IP: net.ParseIP("100.65.62.5"), + }, + "cfj4tiqt2r9s73dmeun0": { + ID: "cfj4tiqt2r9s73dmeun0", + IP: net.ParseIP("100.65.32.206"), + }, + "cg7h032t2r9s73cg5fk0": { + ID: "cg7h032t2r9s73cg5fk0", + IP: net.ParseIP("100.65.250.202"), + }, + "cgcnkj2t2r9s73cg5vv0": { + ID: "cgcnkj2t2r9s73cg5vv0", + IP: net.ParseIP("100.65.13.186"), + }, + "cgcol4qt2r9s73cg601g": { + ID: "cgcol4qt2r9s73cg601g", + IP: net.ParseIP("100.65.29.55"), }, }, Groups: map[string]*Group{ - "gid1": { - ID: "gid1", - Name: "all", - Peers: []string{"peer1", "peer2", "peer3"}, + "cet9e92t2r9s7383ns20": { + ID: "cet9e92t2r9s7383ns20", + Name: "All", + Peers: []string{ + "cfif97at2r9s73au3q0g", + "cfif97at2r9s73au3q00", + "cfif97at2r9s73au3q20", + "cfif97at2r9s73au3q10", + "cfj4tiqt2r9s73dmeun0", + "cg7h032t2r9s73cg5fk0", + "cgcnkj2t2r9s73cg5vv0", + "cgcol4qt2r9s73cg601g", + }, + }, + "cev90bat2r9s7383o150": { + ID: "cev90bat2r9s7383o150", + Name: "swarm", + Peers: []string{ + "cfif97at2r9s73au3q0g", + "cfif97at2r9s73au3q00", + "cfif97at2r9s73au3q20", + "cfj4tiqt2r9s73dmeun0", + "cgcnkj2t2r9s73cg5vv0", + "cgcol4qt2r9s73cg601g", + }, }, }, Rules: map[string]*Rule{ - "default": { - ID: "default", - Name: "default", - Description: "default", - Disabled: false, - Source: []string{"gid1"}, - Destination: []string{"gid1"}, + "cet9e92t2r9s7383ns2g": { + ID: "cet9e92t2r9s7383ns2g", + Name: "Default", + Description: "This is a default rule that allows connections between all the resources", + Source: []string{ + "cet9e92t2r9s7383ns20", + }, + Destination: []string{ + "cet9e92t2r9s7383ns20", + }, + }, + "cev90bat2r9s7383o15g": { + ID: "cev90bat2r9s7383o15g", + Name: "Swarm", + Description: "", + Source: []string{ + "cev90bat2r9s7383o150", + "cet9e92t2r9s7383ns20", + }, + Destination: []string{ + "cev90bat2r9s7383o150", + }, }, }, } - rule, err := RuleToPolicy(account.Rules["default"]) + rule1, err := RuleToPolicy(account.Rules["cet9e92t2r9s7383ns2g"]) assert.NoError(t, err) - account.Policies = append(account.Policies, rule) + rule2, err := RuleToPolicy(account.Rules["cev90bat2r9s7383o15g"]) + assert.NoError(t, err) - peers, firewallRules := account.getPeersByPolicy("peer1") - assert.Len(t, peers, 2) - assert.Contains(t, peers, account.Peers["peer2"]) - assert.Contains(t, peers, account.Peers["peer3"]) + account.Policies = append(account.Policies, rule1, rule2) - epectedFirewallRules := []*FirewallRule{ - {PeerID: "peer1", PeerIP: "10.20.0.1", Direction: "dst", Action: "accept", Port: ""}, - {PeerID: "peer2", PeerIP: "10.20.0.2", Direction: "dst", Action: "accept", Port: ""}, - {PeerID: "peer3", PeerIP: "10.20.0.3", Direction: "dst", Action: "accept", Port: ""}, - {PeerID: "peer1", PeerIP: "10.20.0.1", Direction: "src", Action: "accept", Port: ""}, - {PeerID: "peer2", PeerIP: "10.20.0.2", Direction: "src", Action: "accept", Port: ""}, - {PeerID: "peer3", PeerIP: "10.20.0.3", Direction: "src", Action: "accept", Port: ""}, - } - assert.Len(t, firewallRules, len(epectedFirewallRules)) - for i := range firewallRules { - assert.Equal(t, firewallRules[i], epectedFirewallRules[i]) - } + t.Run("check that all peers get map", func(t *testing.T) { + for _, p := range account.Peers { + peers, firewallRules := account.getPeersByPolicy(p.ID) + assert.GreaterOrEqual(t, len(peers), 2, "mininum number peers should present") + assert.GreaterOrEqual(t, len(firewallRules), 2, "mininum number of firewall rules should present") + } + }) + + t.Run("check first peer map details", func(t *testing.T) { + peers, firewallRules := account.getPeersByPolicy("cfif97at2r9s73au3q0g") + assert.Len(t, peers, 7) + assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q00"]) + assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q10"]) + assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q20"]) + assert.Contains(t, peers, account.Peers["cfj4tiqt2r9s73dmeun0"]) + assert.Contains(t, peers, account.Peers["cg7h032t2r9s73cg5fk0"]) + + epectedFirewallRules := []*FirewallRule{ + { + PeerID: "cfif97at2r9s73au3q00", + PeerIP: "100.65.14.88", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q00100.65.14.88srcaccept", + }, + { + PeerID: "cfif97at2r9s73au3q00", + PeerIP: "100.65.14.88", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q00100.65.14.88dstaccept", + }, + + { + PeerID: "cfif97at2r9s73au3q0g", + PeerIP: "100.65.80.39", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q0g100.65.80.39dstaccept", + }, + { + PeerID: "cfif97at2r9s73au3q0g", + PeerIP: "100.65.80.39", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q0g100.65.80.39srcaccept", + }, + + { + PeerID: "cfif97at2r9s73au3q10", + PeerIP: "100.65.254.139", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q10100.65.254.139dstaccept", + }, + { + PeerID: "cfif97at2r9s73au3q10", + PeerIP: "100.65.254.139", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q10100.65.254.139srcaccept", + }, + + { + PeerID: "cfif97at2r9s73au3q20", + PeerIP: "100.65.62.5", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q20100.65.62.5dstaccept", + }, + { + PeerID: "cfif97at2r9s73au3q20", + PeerIP: "100.65.62.5", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q20100.65.62.5srcaccept", + }, + + { + PeerID: "cfj4tiqt2r9s73dmeun0", + PeerIP: "100.65.32.206", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfj4tiqt2r9s73dmeun0100.65.32.206dstaccept", + }, + { + PeerID: "cfj4tiqt2r9s73dmeun0", + PeerIP: "100.65.32.206", + Direction: "src", + Action: "accept", + Port: "", + id: "cfj4tiqt2r9s73dmeun0100.65.32.206srcaccept", + }, + + { + PeerID: "cg7h032t2r9s73cg5fk0", + PeerIP: "100.65.250.202", + Direction: "dst", + Action: "accept", + Port: "", + id: "cg7h032t2r9s73cg5fk0100.65.250.202dstaccept", + }, + { + PeerID: "cg7h032t2r9s73cg5fk0", + PeerIP: "100.65.250.202", + Direction: "src", + Action: "accept", + Port: "", + id: "cg7h032t2r9s73cg5fk0100.65.250.202srcaccept", + }, + + { + PeerID: "cgcnkj2t2r9s73cg5vv0", + PeerIP: "100.65.13.186", + Direction: "dst", + Action: "accept", + Port: "", + id: "cgcnkj2t2r9s73cg5vv0100.65.13.186dstaccept", + }, + { + PeerID: "cgcnkj2t2r9s73cg5vv0", + PeerIP: "100.65.13.186", + Direction: "src", + Action: "accept", + Port: "", + id: "cgcnkj2t2r9s73cg5vv0100.65.13.186srcaccept", + }, + + { + PeerID: "cgcol4qt2r9s73cg601g", + PeerIP: "100.65.29.55", + Direction: "dst", + Action: "accept", + Port: "", + id: "cgcol4qt2r9s73cg601g100.65.29.55dstaccept", + }, + { + PeerID: "cgcol4qt2r9s73cg601g", + PeerIP: "100.65.29.55", + Direction: "src", + Action: "accept", + Port: "", + id: "cgcol4qt2r9s73cg601g100.65.29.55srcaccept", + }, + } + assert.Len(t, firewallRules, len(epectedFirewallRules)) + slices.SortFunc(firewallRules, func(a, b *FirewallRule) bool { + return a.PeerID < b.PeerID + }) + for i := range firewallRules { + assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + } + }) } diff --git a/management/server/rego/default_policy.rego b/management/server/rego/default_policy.rego index 92e975a02..a1012ae76 100644 --- a/management/server/rego/default_policy.rego +++ b/management/server/rego/default_policy.rego @@ -1,9 +1,9 @@ package netbird all[rule] { - is_peer_in_any_group([{{range $i, $e := .All}}{{if $i}},{{end}}"{{$e}}"{{end}}]) - rule := array.concat( - rules_from_groups([{{range $i, $e := .Destination}}{{if $i}},{{end}}"{{$e}}"{{end}}], "dst", "accept", ""), - rules_from_groups([{{range $i, $e := .Source}}{{if $i}},{{end}}"{{$e}}"{{end}}], "src", "accept", ""), - )[_] + is_peer_in_any_group([{{range $i, $e := .All}}{{if $i}},{{end}}"{{$e}}"{{end}}]) + rule := { + {{range $i, $e := .Destination}}rules_from_group("{{$e}}", "dst", "accept", ""),{{end}} + {{range $i, $e := .Source}}rules_from_group("{{$e}}", "src", "accept", ""),{{end}} + }[_][_] } diff --git a/management/server/rego/default_policy_module.rego b/management/server/rego/default_policy_module.rego index 846e22e21..7411db36a 100644 --- a/management/server/rego/default_policy_module.rego +++ b/management/server/rego/default_policy_module.rego @@ -17,17 +17,11 @@ get_rule(peer_id, direction, action, port) := rule if { } } -# peers_from_group returns a list of peer ids for a given group id -peers_from_group(group_id) := peers if { +# netbird_rules_from_group returns a list of netbird rules for a given group_id +rules_from_group(group_id, direction, action, port) := rules if { group := input.groups[_] group.ID == group_id - peers := [peer | peer := group.Peers[_]] -} - -# netbird_rules_from_groups returns a list of netbird rules for a given list of group names -rules_from_groups(groups, direction, action, port) := rules if { - group_id := groups[_] - rules := [get_rule(peer, direction, action, port) | peer := peers_from_group(group_id)[_]] + rules := [get_rule(peer, direction, action, port) | peer := group.Peers[_]] } # is_peer_in_any_group checks that input peer present at least in one group diff --git a/management/server/user.go b/management/server/user.go index c3011c317..692a2833a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -193,57 +193,148 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us } -// AddPATToUser takes the userID and the accountID the user belongs to and assigns a provided PersonalAccessToken to that user -func (am *DefaultAccountManager) AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error { +// CreatePAT creates a new PAT for the given user +func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() + if tokenName == "" { + return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") + } + + if expiresIn < 1 || expiresIn > 365 { + return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") + } + + if executingUserID != targetUserId { + return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") + } + account, err := am.Store.GetAccount(accountID) if err != nil { - return err + return nil, err } - user := account.Users[userID] - if user == nil { - return status.Errorf(status.NotFound, "user not found") + targetUser := account.Users[targetUserId] + if targetUser == nil { + return nil, status.Errorf(status.NotFound, "targetUser not found") } - user.PATs[pat.ID] = pat + pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) + } - return am.Store.SaveAccount(account) + targetUser.PATs[pat.ID] = &pat.PersonalAccessToken + + err = am.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to save account: %v", err) + } + + meta := map[string]any{"name": pat.Name} + am.storeEvent(executingUserID, targetUserId, accountID, activity.PersonalAccessTokenCreated, meta) + + return pat, nil } // DeletePAT deletes a specific PAT from a user -func (am *DefaultAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { +func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) - if err != nil { - return err + if executingUserID != targetUserID { + return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") } - user := account.Users[userID] + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] if user == nil { return status.Errorf(status.NotFound, "user not found") } - pat := user.PATs["tokenID"] + pat := user.PATs[tokenID] if pat == nil { return status.Errorf(status.NotFound, "PAT not found") } err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) if err != nil { - return err + return status.Errorf(status.Internal, "Failed to delete token id index: %s", err) } err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken) if err != nil { - return err + return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) } + + meta := map[string]any{"name": pat.Name} + am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) + delete(user.PATs, tokenID) - return am.Store.SaveAccount(account) + err = am.Store.SaveAccount(account) + if err != nil { + return status.Errorf(status.Internal, "Failed to save account: %s", err) + } + return nil +} + +// GetPAT returns a specific PAT from a user +func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + if executingUserID != targetUserID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user not found") + } + + pat := user.PATs[tokenID] + if pat == nil { + return nil, status.Errorf(status.NotFound, "PAT not found") + } + + return pat, nil +} + +// GetAllPATs returns all PATs for a user +func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + if executingUserID != targetUserID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user not found") + } + + var pats []*PersonalAccessToken + for _, pat := range user.PATs { + pats = append(pats, pat) + } + + return pats, nil } // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. diff --git a/management/server/user_test.go b/management/server/user_test.go index 20f2ca4f1..29e6bc2bc 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -4,16 +4,25 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/activity" ) const ( - mockAccountID = "accountID" - mockUserID = "userID" - mockTokenID = "tokenID" - mockToken = "SoMeHaShEdToKeN" + mockAccountID = "accountID" + mockUserID = "userID" + mockTargetUserId = "targetUserID" + mockTokenID1 = "tokenID1" + mockToken1 = "SoMeHaShEdToKeN1" + mockTokenID2 = "tokenID2" + mockToken2 = "SoMeHaShEdToKeN2" + mockTokenName = "tokenName" + mockEmptyTokenName = "" + mockExpiresIn = 7 + mockWrongExpiresIn = 4506 ) -func TestUser_AddPATToUser(t *testing.T) { +func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") @@ -23,27 +32,23 @@ func TestUser_AddPATToUser(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } - pat := PersonalAccessToken{ - ID: mockTokenID, - HashedToken: mockToken, - } - - err = am.AddPATToUser(mockAccountID, mockUserID, &pat) + pat, err := am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } fileStore := am.Store.(*FileStore) - tokenID := fileStore.HashedPAT2TokenID[mockToken[:]] + tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken] if tokenID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") } - assert.Equal(t, mockTokenID, tokenID) + assert.Equal(t, pat.ID, tokenID) userID := fileStore.TokenID2UserID[tokenID] if userID == "" { @@ -52,15 +57,69 @@ func TestUser_AddPATToUser(t *testing.T) { assert.Equal(t, mockUserID, userID) } +func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) + assert.Errorf(t, err, "Creating PAT for different user should thorw error") +} + +func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) + assert.Errorf(t, err, "Wrong expiration should thorw error") +} + +func TestUser_CreatePAT_WithEmptyName(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) + assert.Errorf(t, err, "Wrong expiration should thorw error") +} + func TestUser_DeletePAT(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ - mockTokenID: { - ID: mockTokenID, - HashedToken: mockToken, + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, }, }, } @@ -70,15 +129,81 @@ func TestUser_DeletePAT(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } - err = am.DeletePAT(mockAccountID, mockUserID, mockTokenID) + err = am.DeletePAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID]) - assert.Empty(t, store.HashedPAT2TokenID[mockToken]) - assert.Empty(t, store.TokenID2UserID[mockTokenID]) + assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID1]) + assert.Empty(t, store.HashedPAT2TokenID[mockToken1]) + assert.Empty(t, store.TokenID2UserID[mockTokenID1]) +} + +func TestUser_GetPAT(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users[mockUserID] = &User{ + Id: mockUserID, + PATs: map[string]*PersonalAccessToken{ + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + pat, err := am.GetPAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + assert.Equal(t, mockTokenID1, pat.ID) + assert.Equal(t, mockToken1, pat.HashedToken) +} + +func TestUser_GetAllPATs(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users[mockUserID] = &User{ + Id: mockUserID, + PATs: map[string]*PersonalAccessToken{ + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, + }, + mockTokenID2: { + ID: mockTokenID2, + HashedToken: mockToken2, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + pats, err := am.GetAllPATs(mockAccountID, mockUserID, mockUserID) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + assert.Equal(t, 2, len(pats)) } diff --git a/release_files/install.sh b/release_files/install.sh new file mode 100644 index 000000000..fda7ea56e --- /dev/null +++ b/release_files/install.sh @@ -0,0 +1,284 @@ +#!/bin/sh +# This code is based on the netbird-installer contribution by physk on GitHub. +# Source: https://github.com/physk/netbird-installer +set -e + +OWNER="netbirdio" +REPO="netbird" +CLI_APP="netbird" +UI_APP="netbird-ui" + +# Set default variable +OS_NAME="" +OS_TYPE="" +ARCH="$(uname -m)" +PACKAGE_MANAGER="" +INSTALL_DIR="" + +get_latest_release() { + curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/latest" \ + | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' +} + +download_release_binary() { + VERSION=$(get_latest_release) + BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" + BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" + + # for Darwin, download the signed Netbird-UI + if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then + BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}_signed.zip" + fi + + BINARY_NAME="$1_${BINARY_BASE_NAME}" + DOWNLOAD_URL="${BASE_URL}/${VERSION}/${BINARY_NAME}" + + echo "Installing $1 from $DOWNLOAD_URL" + cd /tmp && curl -LO "$DOWNLOAD_URL" + + if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then + INSTALL_DIR="/Applications/NetBird UI.app" + + # Unzip the app and move to INSTALL_DIR + unzip -q -o "$BINARY_NAME" + mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR" + else + tar -xzvf "$BINARY_NAME" + sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR" + fi +} + +add_apt_repo() { + sudo apt-get update + sudo apt-get install ca-certificates gnupg -y + + curl -sSL https://pkgs.wiretrustee.com/debian/public.key \ + | sudo gpg --dearmor --output /usr/share/keyrings/wiretrustee-archive-keyring.gpg + + APT_REPO="deb [signed-by=/usr/share/keyrings/wiretrustee-archive-keyring.gpg] https://pkgs.wiretrustee.com/debian stable main" + echo "$APT_REPO" | sudo tee /etc/apt/sources.list.d/wiretrustee.list + + sudo apt-get update +} + +add_rpm_repo() { +cat <<-EOF | sudo tee /etc/yum.repos.d/netbird.repo +[Netbird] +name=Netbird +baseurl=https://pkgs.netbird.io/yum/ +enabled=1 +gpgcheck=0 +gpgkey=https://pkgs.netbird.io/yum/repodata/repomd.xml.key +repo_gpgcheck=1 +EOF +} + +add_aur_repo() { + INSTALL_PKGS="git base-devel go" + REMOVE_PKGS="" + + # Check if dependencies are installed + for PKG in $INSTALL_PKGS; do + if ! pacman -Q "$PKG" > /dev/null 2>&1; then + # Install missing package(s) + sudo pacman -S "$PKG" --noconfirm + + # Add installed package for clean up later + REMOVE_PKGS="$REMOVE_PKGS $PKG" + fi + done + + # Build package from AUR + cd /tmp && git clone https://aur.archlinux.org/netbird.git + cd netbird && makepkg -sri --noconfirm + + if ! $SKIP_UI_APP; then + cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git + cd netbird-ui && makepkg -sri --noconfirm + fi + + # Clean up the installed packages + sudo pacman -Rs "$REMOVE_PKGS" --noconfirm +} + +install_native_binaries() { + # Checks for supported architecture + case "$ARCH" in + x86_64|amd64) + ARCH="amd64" + ;; + i?86|x86) + ARCH="386" + ;; + aarch64|arm64) + ARCH="arm64" + ;; + *) + echo "Architecture ${ARCH} not supported" + exit 2 + ;; + esac + + # download and copy binaries to INSTALL_DIR + download_release_binary "$CLI_APP" + if ! $SKIP_UI_APP; then + download_release_binary "$UI_APP" + fi +} + +install_netbird() { + # Check if netbird CLI is installed + if [ -x "$(command -v netbird)" ]; then + if netbird status > /dev/null 2>&1; then + echo "Netbird service is running, please stop it before proceeding" + fi + + echo "Netbird seems to be installed already, please remove it before proceeding" + exit 1 + fi + + # Checks if SKIP_UI_APP env is set + if [ -z "$SKIP_UI_APP" ]; then + SKIP_UI_APP=false + else + if $SKIP_UI_APP; then + echo "SKIP_UI_APP has been set to true in the environment" + echo "Netbird UI installation will be omitted based on your preference" + fi + fi + + # Identify OS name and default package manager + if type uname >/dev/null 2>&1; then + case "$(uname)" in + Linux) + OS_NAME="$(. /etc/os-release && echo "$ID")" + OS_TYPE="linux" + INSTALL_DIR="/usr/bin" + + # Allow netbird UI installation for x64 arch only + if [ "$ARCH" != "amd64" ] && [ "$ARCH" != "arm64" ] \ + && [ "$ARCH" != "x86_64" ];then + SKIP_UI_APP=true + echo "Netbird UI installation will be omitted as $ARCH is not a compactible architecture" + fi + + # Allow netbird UI installation for linux running desktop enviroment + if [ -z "$XDG_CURRENT_DESKTOP" ];then + SKIP_UI_APP=true + echo "Netbird UI installation will be omitted as Linux does not run desktop environment" + fi + + # Check the availability of a compactible package manager + if [ -x "$(command -v apt)" ]; then + PACKAGE_MANAGER="apt" + echo "The installation will be performed using apt package manager" + elif [ -x "$(command -v dnf)" ]; then + PACKAGE_MANAGER="dnf" + echo "The installation will be performed using dnf package manager" + elif [ -x "$(command -v yum)" ]; then + PACKAGE_MANAGER="yum" + echo "The installation will be performed using yum package manager" + elif [ -x "$(command -v pacman)" ]; then + PACKAGE_MANAGER="pacman" + echo "The installation will be performed using pacman package manager" + fi + ;; + Darwin) + OS_NAME="macos" + OS_TYPE="darwin" + INSTALL_DIR="/usr/local/bin" + + # Check the availability of a compatible package manager + if [ -x "$(command -v brew)" ]; then + PACKAGE_MANAGER="brew" + echo "The installation will be performed using brew package manager" + fi + ;; + esac + fi + + # Run the installation, if a desktop environment is not detected + # only the CLI will be installed + case "$PACKAGE_MANAGER" in + apt) + add_apt_repo + sudo apt-get install netbird -y + + if ! $SKIP_UI_APP; then + sudo apt-get install netbird-ui -y + fi + ;; + yum) + add_rpm_repo + sudo yum -y install netbird + if ! $SKIP_UI_APP; then + sudo yum -y install netbird-ui + fi + ;; + dnf) + add_rpm_repo + sudo dnf -y install dnf-plugin-config-manager + sudo dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + sudo dnf -y install netbird + + if ! $SKIP_UI_APP; then + sudo dnf -y install netbird-ui + fi + ;; + pacman) + sudo pacman -Syy + add_aur_repo + ;; + brew) + # Remove Wiretrustee if it had been installed using Homebrew before + if brew ls --versions wiretrustee >/dev/null 2>&1; then + echo "Removing existing wiretrustee client" + + # Stop and uninstall daemon service: + wiretrustee service stop + wiretrustee service uninstall + + # Unlik the app + brew unlink wiretrustee + fi + + brew install netbirdio/tap/netbird + if ! $SKIP_UI_APP; then + brew install --cask netbirdio/tap/netbird-ui + fi + ;; + *) + if [ "$OS_NAME" = "nixos" ];then + echo "Please add Netbird to your NixOS configuration.nix directly:" + echo + echo "services.netbird.enable = true;" + + if ! $SKIP_UI_APP; then + echo "environment.systemPackages = [ pkgs.netbird-ui ];" + fi + + echo "Build and apply new configuration:" + echo + echo "sudo nixos-rebuild switch" + exit 0 + fi + + install_native_binaries + ;; + esac + + # Load and start netbird service + if ! sudo netbird service install 2>&1; then + echo "Netbird service has already been loaded" + fi + if ! sudo netbird service start 2>&1; then + echo "Netbird service has already been started" + fi + + + echo "Installation has been finished. To connect, you need to run NetBird by executing the following command:" + echo "" + echo "sudo netbird up" +} + +install_netbird \ No newline at end of file