mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 18:26:41 +00:00
fix minor potential security issues with OIDC
This commit is contained in:
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
@@ -15,8 +13,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const stateExpiration = 10 * time.Minute
|
||||
|
||||
type urlGenerator interface {
|
||||
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
||||
}
|
||||
@@ -41,13 +37,11 @@ type OIDC struct {
|
||||
validator *jwt.Validator
|
||||
maxTokenAgeSeconds int64
|
||||
client urlGenerator
|
||||
states map[string]*oidcState
|
||||
statesMux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
||||
o := &OIDC{
|
||||
return &OIDC{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
validator: jwt.NewValidator(
|
||||
@@ -58,12 +52,7 @@ func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
||||
),
|
||||
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
|
||||
client: client,
|
||||
states: make(map[string]*oidcState),
|
||||
}
|
||||
|
||||
go o.cleanupStates()
|
||||
|
||||
return o
|
||||
}
|
||||
|
||||
func (*OIDC) Type() Method {
|
||||
@@ -71,9 +60,8 @@ func (*OIDC) Type() Method {
|
||||
}
|
||||
|
||||
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
||||
// Try Authorization: Bearer <token> header
|
||||
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
|
||||
if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
|
||||
if token := r.URL.Query().Get("access_token"); token != "" {
|
||||
if userID := o.validateToken(r.Context(), token); userID != "" {
|
||||
return userID, ""
|
||||
}
|
||||
}
|
||||
@@ -116,7 +104,8 @@ func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
||||
// If max token age is 0 skip this check.
|
||||
if o.maxTokenAgeSeconds > 0 && time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
@@ -147,17 +136,3 @@ func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// cleanupStates periodically removes expired states
|
||||
func (o *OIDC) cleanupStates() {
|
||||
for range time.Tick(time.Minute) {
|
||||
cutoff := time.Now().Add(-stateExpiration)
|
||||
o.statesMux.Lock()
|
||||
for k, v := range o.states {
|
||||
if v.CreatedAt.Before(cutoff) {
|
||||
delete(o.states, k)
|
||||
}
|
||||
}
|
||||
o.statesMux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user