mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
fix minor potential security issues with OIDC
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -407,6 +408,29 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStat
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
||||||
|
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
||||||
|
if err != nil {
|
||||||
|
// TODO: log
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err)
|
||||||
|
}
|
||||||
|
// Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection.
|
||||||
|
proxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, req.GetAccountId())
|
||||||
|
if err != nil {
|
||||||
|
// TODO: log
|
||||||
|
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
|
||||||
|
}
|
||||||
|
var found bool
|
||||||
|
for _, proxy := range proxies {
|
||||||
|
if proxy.Domain == redirectURL.Hostname() {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
// TODO: log
|
||||||
|
return nil, status.Errorf(codes.FailedPrecondition, "reverse proxy not found in store")
|
||||||
|
}
|
||||||
|
|
||||||
provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer)
|
provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: log
|
// TODO: log
|
||||||
@@ -420,9 +444,8 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
|||||||
|
|
||||||
// Using an HMAC here to avoid redirection state being modified.
|
// Using an HMAC here to avoid redirection state being modified.
|
||||||
// State format: base64(redirectURL)|hmac
|
// State format: base64(redirectURL)|hmac
|
||||||
redirectURL := req.GetRedirectUrl()
|
hmacSum := s.generateHMAC(redirectURL.String())
|
||||||
hmacSum := s.generateHMAC(redirectURL)
|
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), hmacSum)
|
||||||
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL)), hmacSum)
|
|
||||||
|
|
||||||
codeVerifier := oauth2.GenerateVerifier()
|
codeVerifier := oauth2.GenerateVerifier()
|
||||||
s.pkceVerifiers.Store(state, codeVerifier)
|
s.pkceVerifiers.Store(state, codeVerifier)
|
||||||
|
|||||||
@@ -72,6 +72,9 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
|||||||
}
|
}
|
||||||
redirectURL.RawQuery = redirectQuery.Encode()
|
redirectURL.RawQuery = redirectQuery.Encode()
|
||||||
|
|
||||||
log.WithField("redirect", redirectURL).Debug("OAuth callback: redirecting user with token")
|
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
|
||||||
|
redirectURL.Scheme = "https"
|
||||||
|
|
||||||
|
log.WithField("redirect", redirectURL.String()).Debug("OAuth callback: redirecting user with token")
|
||||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
gojwt "github.com/golang-jwt/jwt/v5"
|
gojwt "github.com/golang-jwt/jwt/v5"
|
||||||
@@ -15,8 +13,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
const stateExpiration = 10 * time.Minute
|
|
||||||
|
|
||||||
type urlGenerator interface {
|
type urlGenerator interface {
|
||||||
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
||||||
}
|
}
|
||||||
@@ -41,13 +37,11 @@ type OIDC struct {
|
|||||||
validator *jwt.Validator
|
validator *jwt.Validator
|
||||||
maxTokenAgeSeconds int64
|
maxTokenAgeSeconds int64
|
||||||
client urlGenerator
|
client urlGenerator
|
||||||
states map[string]*oidcState
|
|
||||||
statesMux sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOIDC creates a new OIDC authentication scheme
|
// NewOIDC creates a new OIDC authentication scheme
|
||||||
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
||||||
o := &OIDC{
|
return &OIDC{
|
||||||
id: id,
|
id: id,
|
||||||
accountId: accountId,
|
accountId: accountId,
|
||||||
validator: jwt.NewValidator(
|
validator: jwt.NewValidator(
|
||||||
@@ -58,12 +52,7 @@ func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
|||||||
),
|
),
|
||||||
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
|
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
|
||||||
client: client,
|
client: client,
|
||||||
states: make(map[string]*oidcState),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go o.cleanupStates()
|
|
||||||
|
|
||||||
return o
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*OIDC) Type() Method {
|
func (*OIDC) Type() Method {
|
||||||
@@ -71,9 +60,8 @@ func (*OIDC) Type() Method {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
||||||
// Try Authorization: Bearer <token> header
|
if token := r.URL.Query().Get("access_token"); token != "" {
|
||||||
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
|
if userID := o.validateToken(r.Context(), token); userID != "" {
|
||||||
if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
|
|
||||||
return userID, ""
|
return userID, ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -116,7 +104,8 @@ func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
// If max token age is 0 skip this check.
|
||||||
|
if o.maxTokenAgeSeconds > 0 && time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
||||||
// TODO: log or return?
|
// TODO: log or return?
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -147,17 +136,3 @@ func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
|||||||
}
|
}
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupStates periodically removes expired states
|
|
||||||
func (o *OIDC) cleanupStates() {
|
|
||||||
for range time.Tick(time.Minute) {
|
|
||||||
cutoff := time.Now().Add(-stateExpiration)
|
|
||||||
o.statesMux.Lock()
|
|
||||||
for k, v := range o.states {
|
|
||||||
if v.CreatedAt.Before(cutoff) {
|
|
||||||
delete(o.states, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
o.statesMux.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user