mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
fix minor potential security issues with OIDC
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -407,6 +408,29 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStat
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
||||
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
||||
if err != nil {
|
||||
// TODO: log
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err)
|
||||
}
|
||||
// Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection.
|
||||
proxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, req.GetAccountId())
|
||||
if err != nil {
|
||||
// TODO: log
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
|
||||
}
|
||||
var found bool
|
||||
for _, proxy := range proxies {
|
||||
if proxy.Domain == redirectURL.Hostname() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// TODO: log
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "reverse proxy not found in store")
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer)
|
||||
if err != nil {
|
||||
// TODO: log
|
||||
@@ -420,9 +444,8 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
|
||||
// Using an HMAC here to avoid redirection state being modified.
|
||||
// State format: base64(redirectURL)|hmac
|
||||
redirectURL := req.GetRedirectUrl()
|
||||
hmacSum := s.generateHMAC(redirectURL)
|
||||
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL)), hmacSum)
|
||||
hmacSum := s.generateHMAC(redirectURL.String())
|
||||
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), hmacSum)
|
||||
|
||||
codeVerifier := oauth2.GenerateVerifier()
|
||||
s.pkceVerifiers.Store(state, codeVerifier)
|
||||
|
||||
@@ -72,6 +72,9 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
redirectURL.RawQuery = redirectQuery.Encode()
|
||||
|
||||
log.WithField("redirect", redirectURL).Debug("OAuth callback: redirecting user with token")
|
||||
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
|
||||
redirectURL.Scheme = "https"
|
||||
|
||||
log.WithField("redirect", redirectURL.String()).Debug("OAuth callback: redirecting user with token")
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
@@ -15,8 +13,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const stateExpiration = 10 * time.Minute
|
||||
|
||||
type urlGenerator interface {
|
||||
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
||||
}
|
||||
@@ -41,13 +37,11 @@ type OIDC struct {
|
||||
validator *jwt.Validator
|
||||
maxTokenAgeSeconds int64
|
||||
client urlGenerator
|
||||
states map[string]*oidcState
|
||||
statesMux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
||||
o := &OIDC{
|
||||
return &OIDC{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
validator: jwt.NewValidator(
|
||||
@@ -58,12 +52,7 @@ func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
||||
),
|
||||
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
|
||||
client: client,
|
||||
states: make(map[string]*oidcState),
|
||||
}
|
||||
|
||||
go o.cleanupStates()
|
||||
|
||||
return o
|
||||
}
|
||||
|
||||
func (*OIDC) Type() Method {
|
||||
@@ -71,9 +60,8 @@ func (*OIDC) Type() Method {
|
||||
}
|
||||
|
||||
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
||||
// Try Authorization: Bearer <token> header
|
||||
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
|
||||
if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
|
||||
if token := r.URL.Query().Get("access_token"); token != "" {
|
||||
if userID := o.validateToken(r.Context(), token); userID != "" {
|
||||
return userID, ""
|
||||
}
|
||||
}
|
||||
@@ -116,7 +104,8 @@ func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
||||
// If max token age is 0 skip this check.
|
||||
if o.maxTokenAgeSeconds > 0 && time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
@@ -147,17 +136,3 @@ func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// cleanupStates periodically removes expired states
|
||||
func (o *OIDC) cleanupStates() {
|
||||
for range time.Tick(time.Minute) {
|
||||
cutoff := time.Now().Add(-stateExpiration)
|
||||
o.statesMux.Lock()
|
||||
for k, v := range o.states {
|
||||
if v.CreatedAt.Before(cutoff) {
|
||||
delete(o.states, k)
|
||||
}
|
||||
}
|
||||
o.statesMux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user