fix minor potential security issues with OIDC

This commit is contained in:
Alisdair MacLeod
2026-02-04 12:25:19 +00:00
parent a89bb807a6
commit a0005a604e
3 changed files with 35 additions and 34 deletions

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net/url"
"strings"
"sync"
"time"
@@ -407,6 +408,29 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStat
}
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
redirectURL, err := url.Parse(req.GetRedirectUrl())
if err != nil {
// TODO: log
return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err)
}
// Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection.
proxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, req.GetAccountId())
if err != nil {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
}
var found bool
for _, proxy := range proxies {
if proxy.Domain == redirectURL.Hostname() {
found = true
break
}
}
if !found {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "reverse proxy not found in store")
}
provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer)
if err != nil {
// TODO: log
@@ -420,9 +444,8 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
// Using an HMAC here to avoid redirection state being modified.
// State format: base64(redirectURL)|hmac
redirectURL := req.GetRedirectUrl()
hmacSum := s.generateHMAC(redirectURL)
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL)), hmacSum)
hmacSum := s.generateHMAC(redirectURL.String())
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), hmacSum)
codeVerifier := oauth2.GenerateVerifier()
s.pkceVerifiers.Store(state, codeVerifier)

View File

@@ -72,6 +72,9 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
}
redirectURL.RawQuery = redirectQuery.Encode()
log.WithField("redirect", redirectURL).Debug("OAuth callback: redirecting user with token")
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
redirectURL.Scheme = "https"
log.WithField("redirect", redirectURL.String()).Debug("OAuth callback: redirecting user with token")
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}

View File

@@ -4,8 +4,6 @@ import (
"context"
"net/http"
"net/url"
"strings"
"sync"
"time"
gojwt "github.com/golang-jwt/jwt/v5"
@@ -15,8 +13,6 @@ import (
"github.com/netbirdio/netbird/shared/auth/jwt"
)
const stateExpiration = 10 * time.Minute
type urlGenerator interface {
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
}
@@ -41,13 +37,11 @@ type OIDC struct {
validator *jwt.Validator
maxTokenAgeSeconds int64
client urlGenerator
states map[string]*oidcState
statesMux sync.RWMutex
}
// NewOIDC creates a new OIDC authentication scheme
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
o := &OIDC{
return &OIDC{
id: id,
accountId: accountId,
validator: jwt.NewValidator(
@@ -58,12 +52,7 @@ func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
),
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
client: client,
states: make(map[string]*oidcState),
}
go o.cleanupStates()
return o
}
func (*OIDC) Type() Method {
@@ -71,9 +60,8 @@ func (*OIDC) Type() Method {
}
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
// Try Authorization: Bearer <token> header
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
if token := r.URL.Query().Get("access_token"); token != "" {
if userID := o.validateToken(r.Context(), token); userID != "" {
return userID, ""
}
}
@@ -116,7 +104,8 @@ func (o *OIDC) validateToken(ctx context.Context, token string) string {
return ""
}
if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
// If max token age is 0 skip this check.
if o.maxTokenAgeSeconds > 0 && time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
// TODO: log or return?
return ""
}
@@ -147,17 +136,3 @@ func getUserIDFromClaims(claims gojwt.MapClaims) string {
}
return "unknown"
}
// cleanupStates periodically removes expired states
func (o *OIDC) cleanupStates() {
for range time.Tick(time.Minute) {
cutoff := time.Now().Add(-stateExpiration)
o.statesMux.Lock()
for k, v := range o.states {
if v.CreatedAt.Before(cutoff) {
delete(o.states, k)
}
}
o.statesMux.Unlock()
}
}