mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
refactor layout and structure
This commit is contained in:
4
proxy/internal/auth/auth.gohtml
Normal file
4
proxy/internal/auth/auth.gohtml
Normal file
@@ -0,0 +1,4 @@
|
||||
<!doctype html>
|
||||
{{ range . }}
|
||||
<p>{{ . }}</p>
|
||||
{{ end }}
|
||||
42
proxy/internal/auth/basicauth.go
Normal file
42
proxy/internal/auth/basicauth.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type BasicAuth struct {
|
||||
username, password string
|
||||
}
|
||||
|
||||
func NewBasicAuth(username string, password string) BasicAuth {
|
||||
return BasicAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
}
|
||||
|
||||
func (BasicAuth) Type() Method {
|
||||
return MethodBasicAuth
|
||||
}
|
||||
|
||||
func (b BasicAuth) Authenticate(r *http.Request) (string, bool, any) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(b.username)) == 1
|
||||
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(b.password)) == 1
|
||||
|
||||
// If authenticated, then return the username.
|
||||
if usernameMatch && passwordMatch {
|
||||
return username, false, nil
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (b BasicAuth) Middleware(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package auth
|
||||
|
||||
import "github.com/netbirdio/netbird/proxy/internal/auth/methods"
|
||||
|
||||
// Config holds the authentication configuration for a route
|
||||
// Only ONE auth method should be configured per route
|
||||
type Config struct {
|
||||
// HTTP Basic authentication (username/password)
|
||||
BasicAuth *methods.BasicAuthConfig
|
||||
|
||||
// PIN authentication
|
||||
PIN *methods.PINConfig
|
||||
|
||||
// Bearer token with JWT validation and OAuth/OIDC flow
|
||||
// When enabled, uses the global OIDCConfig from proxy Config
|
||||
Bearer *methods.BearerConfig
|
||||
}
|
||||
|
||||
// IsEmpty returns true if no auth methods are configured
|
||||
func (c *Config) IsEmpty() bool {
|
||||
if c == nil {
|
||||
return true
|
||||
}
|
||||
return c.BasicAuth == nil && c.PIN == nil && c.Bearer == nil
|
||||
}
|
||||
38
proxy/internal/auth/context.go
Normal file
38
proxy/internal/auth/context.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type requestContextKey string
|
||||
|
||||
const (
|
||||
authMethodKey requestContextKey = "authMethod"
|
||||
authUserKey requestContextKey = "authUser"
|
||||
)
|
||||
|
||||
func withAuthMethod(ctx context.Context, method Method) context.Context {
|
||||
return context.WithValue(ctx, authMethodKey, method)
|
||||
}
|
||||
|
||||
func MethodFromContext(ctx context.Context) Method {
|
||||
v := ctx.Value(authMethodKey)
|
||||
method, ok := v.(Method)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return method
|
||||
}
|
||||
|
||||
func withAuthUser(ctx context.Context, userId string) context.Context {
|
||||
return context.WithValue(ctx, authUserKey, userId)
|
||||
}
|
||||
|
||||
func UserFromContext(ctx context.Context) string {
|
||||
v := ctx.Value(authUserKey)
|
||||
userId, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return userId
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package methods
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// BasicAuthConfig holds HTTP Basic authentication settings
|
||||
type BasicAuthConfig struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// Validate checks Basic Auth credentials from the request
|
||||
func (c *BasicAuthConfig) Validate(r *http.Request) bool {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(c.Username)) == 1
|
||||
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(c.Password)) == 1
|
||||
|
||||
return usernameMatch && passwordMatch
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package methods
|
||||
|
||||
// BearerConfig holds JWT/OAuth/OIDC bearer token authentication settings
|
||||
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
|
||||
// This just enables Bearer auth for a specific route
|
||||
type BearerConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package methods
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPINHeader is the default header name for PIN authentication
|
||||
DefaultPINHeader = "X-PIN"
|
||||
)
|
||||
|
||||
// PINConfig holds PIN authentication settings
|
||||
type PINConfig struct {
|
||||
PIN string
|
||||
Header string
|
||||
}
|
||||
|
||||
// Validate checks PIN from the request header
|
||||
func (c *PINConfig) Validate(r *http.Request) bool {
|
||||
header := c.Header
|
||||
if header == "" {
|
||||
header = DefaultPINHeader
|
||||
}
|
||||
|
||||
providedPIN := r.Header.Get(header)
|
||||
if providedPIN == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare([]byte(providedPIN), []byte(c.PIN)) == 1
|
||||
}
|
||||
@@ -1,298 +1,198 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"crypto/rand"
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Middleware wraps an HTTP handler with authentication middleware
|
||||
//go:embed auth.gohtml
|
||||
var authTemplate string
|
||||
|
||||
type Method string
|
||||
|
||||
var (
|
||||
MethodBasicAuth Method = "basic"
|
||||
MethodPIN Method = "pin"
|
||||
MethodBearer Method = "bearer"
|
||||
)
|
||||
|
||||
func (m Method) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
const (
|
||||
sessionCookieName = "nb_session"
|
||||
sessionExpiration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type session struct {
|
||||
UserID string
|
||||
Method Method
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Scheme interface {
|
||||
Type() Method
|
||||
// Authenticate should check the passed request and determine whether
|
||||
// it represents an authenticated user request. If it does not, then
|
||||
// an empty string should indicate an unauthenticated request which
|
||||
// will be rejected; optionally, it can also return any data that should
|
||||
// be included in a UI template when prompting the user to authenticate.
|
||||
// If the request is authenticated, then a user id should be returned
|
||||
// along with a boolean indicating whether a redirect is needed to clean
|
||||
// up authentication artifacts from the URLs query.
|
||||
Authenticate(*http.Request) (userid string, needsRedirect bool, promptData any)
|
||||
// Middleware is applied within the outer auth middleware, but they will
|
||||
// be applied after authentication if no scheme has authenticated a
|
||||
// request.
|
||||
// If no scheme Middleware blocks the request processing, then the auth
|
||||
// middleware will then present the user with the auth UI.
|
||||
Middleware(http.Handler) http.Handler
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
next http.Handler
|
||||
config *Config
|
||||
routeID string
|
||||
rejectResponse func(w http.ResponseWriter, r *http.Request)
|
||||
oidcHandler *oidc.Handler // OIDC handler for OAuth flow (contains config and JWT validator)
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string][]Scheme
|
||||
sessionsMux sync.RWMutex
|
||||
sessions map[string]*session
|
||||
}
|
||||
|
||||
// authResult holds the result of an authentication attempt
|
||||
type authResult struct {
|
||||
authenticated bool
|
||||
method string
|
||||
userID string
|
||||
func NewMiddleware() *Middleware {
|
||||
mw := &Middleware{
|
||||
domains: make(map[string][]Scheme),
|
||||
sessions: make(map[string]*session),
|
||||
}
|
||||
// TODO: goroutine is leaked here.
|
||||
go mw.cleanupSessions()
|
||||
return mw
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface
|
||||
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if m.config.IsEmpty() {
|
||||
m.allowWithoutAuth(w, r)
|
||||
return
|
||||
}
|
||||
// Protect applies authentication middleware to the passed handler.
|
||||
// For each incoming request it will be checked against the middleware's
|
||||
// internal list of protected domains.
|
||||
// If the Host domain in the inbound request is not present, then it will
|
||||
// simply be passed through.
|
||||
// However, if the Host domain is present, then the specified authentication
|
||||
// schemes for that domain will be applied to the request.
|
||||
// In the event that no authentication schemes are defined for the domain,
|
||||
// then the request will also be simply passed through.
|
||||
func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
tmpl := template.Must(template.New("auth").Parse(authTemplate))
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mw.domainsMux.RLock()
|
||||
schemes, exists := mw.domains[r.Host]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
result := m.authenticate(w, r)
|
||||
if result == nil {
|
||||
// Authentication triggered a redirect (e.g., OIDC flow)
|
||||
return
|
||||
}
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
if !exists || len(schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.authenticated {
|
||||
m.rejectRequest(w, r)
|
||||
return
|
||||
}
|
||||
// Check for an existing session to avoid users having to authenticate for every request.
|
||||
// TODO: This does not work if you are load balancing across multiple proxy servers.
|
||||
if cookie, err := r.Cookie(sessionCookieName); err == nil {
|
||||
mw.sessionsMux.RLock()
|
||||
sess, ok := mw.sessions[cookie.Value]
|
||||
mw.sessionsMux.RUnlock()
|
||||
if ok {
|
||||
ctx := withAuthMethod(r.Context(), sess.Method)
|
||||
ctx = withAuthUser(ctx, sess.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.continueWithAuth(w, r, result)
|
||||
// Try to authenticate with each scheme.
|
||||
methods := make(map[Method]any)
|
||||
for _, s := range schemes {
|
||||
userid, needsRedirect, promptData := s.Authenticate(r)
|
||||
if userid != "" {
|
||||
mw.createSession(w, r, userid, s.Type())
|
||||
if needsRedirect {
|
||||
// Clean the path and redirect to the naked URL.
|
||||
// This is intended to prevent leaking potentially
|
||||
// sensitive query parameters for some authentication
|
||||
// methods such as OIDC.
|
||||
http.Redirect(w, r, r.URL.Path, http.StatusFound)
|
||||
return
|
||||
}
|
||||
ctx := withAuthMethod(r.Context(), s.Type())
|
||||
ctx = withAuthUser(ctx, userid)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
methods[s.Type()] = promptData
|
||||
}
|
||||
|
||||
// The handler is passed through the scheme middlewares,
|
||||
// if none of them intercept the request, then this handler will
|
||||
// be called and present the user with the authentication page.
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := tmpl.Execute(w, methods); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
}))
|
||||
|
||||
// No authentication succeeded. Apply the scheme handlers.
|
||||
for _, s := range schemes {
|
||||
handler = s.Middleware(handler)
|
||||
}
|
||||
|
||||
// Run the unauthenticated request against the scheme handlers and the final UI handler.
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// allowWithoutAuth allows requests when no authentication is configured
|
||||
func (m *Middleware) allowWithoutAuth(w http.ResponseWriter, r *http.Request) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"auth_method": "none",
|
||||
"path": r.URL.Path,
|
||||
}).Debug("No authentication configured, allowing request")
|
||||
r.Header.Set("X-Auth-Method", "none")
|
||||
m.next.ServeHTTP(w, r)
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme) {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = schemes
|
||||
}
|
||||
|
||||
// authenticate attempts to authenticate the request using configured methods
|
||||
// Returns nil if a redirect occurred (e.g., OIDC flow initiated)
|
||||
func (m *Middleware) authenticate(w http.ResponseWriter, r *http.Request) *authResult {
|
||||
if result := m.tryBasicAuth(r); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
if result := m.tryPINAuth(r); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
return m.tryBearerAuth(w, r)
|
||||
func (mw *Middleware) RemoveDomain(domain string) {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
delete(mw.domains, domain)
|
||||
}
|
||||
|
||||
// tryBasicAuth attempts Basic authentication
|
||||
func (m *Middleware) tryBasicAuth(r *http.Request) *authResult {
|
||||
if m.config.BasicAuth == nil {
|
||||
return &authResult{}
|
||||
}
|
||||
func (mw *Middleware) createSession(w http.ResponseWriter, r *http.Request, userID string, method Method) {
|
||||
// Generate a random sessionID
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
sessionID := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
if !m.config.BasicAuth.Validate(r) {
|
||||
return &authResult{}
|
||||
mw.sessionsMux.Lock()
|
||||
mw.sessions[sessionID] = &session{
|
||||
UserID: userID,
|
||||
Method: method,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mw.sessionsMux.Unlock()
|
||||
|
||||
result := &authResult{
|
||||
authenticated: true,
|
||||
method: "basic",
|
||||
}
|
||||
|
||||
if username, _, ok := r.BasicAuth(); ok {
|
||||
result.userID = username
|
||||
}
|
||||
|
||||
return result
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: sessionID,
|
||||
HttpOnly: true, // This cookie is only for proxy access, so no scripts should touch it.
|
||||
Secure: true, // The proxy only accepts TLS traffic regardless of the service proxied behind.
|
||||
SameSite: http.SameSiteLaxMode, // TODO: might this actually be strict mode?
|
||||
})
|
||||
}
|
||||
|
||||
// tryPINAuth attempts PIN authentication
|
||||
func (m *Middleware) tryPINAuth(r *http.Request) *authResult {
|
||||
if m.config.PIN == nil {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
if !m.config.PIN.Validate(r) {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
return &authResult{
|
||||
authenticated: true,
|
||||
method: "pin",
|
||||
userID: "pin_user",
|
||||
}
|
||||
}
|
||||
|
||||
// tryBearerAuth attempts Bearer token authentication with JWT validation
|
||||
// Returns nil if OIDC redirect occurred
|
||||
func (m *Middleware) tryBearerAuth(w http.ResponseWriter, r *http.Request) *authResult {
|
||||
if m.config.Bearer == nil || m.oidcHandler == nil {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
cookieName := m.oidcHandler.SessionCookieName()
|
||||
|
||||
if m.handleAuthTokenParameter(w, r, cookieName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result := m.trySessionCookie(r, cookieName); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
if result := m.tryAuthorizationHeader(r); result.authenticated {
|
||||
return result
|
||||
}
|
||||
|
||||
m.oidcHandler.RedirectToProvider(w, r, m.routeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAuthTokenParameter processes the _auth_token query parameter from OIDC callback
|
||||
// Returns true if a redirect occurred
|
||||
func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Request, cookieName string) bool {
|
||||
authToken := r.URL.Query().Get("_auth_token")
|
||||
if authToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"host": r.Host,
|
||||
}).Info("Found auth token in query parameter, setting cookie and redirecting")
|
||||
|
||||
if !m.oidcHandler.ValidateJWT(authToken) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
}).Warn("Invalid token in query parameter")
|
||||
return false
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: authToken,
|
||||
Path: "/",
|
||||
MaxAge: 3600, // 1 hour
|
||||
HttpOnly: true,
|
||||
Secure: false, // Set to false for HTTP testing, true for HTTPS in production
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
// Redirect to same URL without the token parameter
|
||||
redirectURL := m.buildCleanRedirectURL(r)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"redirect_url": redirectURL,
|
||||
}).Debug("Redirecting to clean URL after setting cookie")
|
||||
|
||||
http.Redirect(w, r, redirectURL, http.StatusFound)
|
||||
return true
|
||||
}
|
||||
|
||||
// buildCleanRedirectURL builds a redirect URL without the _auth_token parameter
|
||||
func (m *Middleware) buildCleanRedirectURL(r *http.Request) string {
|
||||
cleanURL := *r.URL
|
||||
q := cleanURL.Query()
|
||||
q.Del("_auth_token")
|
||||
cleanURL.RawQuery = q.Encode()
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", scheme, r.Host, cleanURL.String())
|
||||
}
|
||||
|
||||
// trySessionCookie attempts authentication using a session cookie
|
||||
func (m *Middleware) trySessionCookie(r *http.Request, cookieName string) *authResult {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"cookie_name": cookieName,
|
||||
"host": r.Host,
|
||||
"path": r.URL.Path,
|
||||
}).Debug("Checking for session cookie")
|
||||
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"error": err,
|
||||
}).Debug("No session cookie found")
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"cookie_name": cookieName,
|
||||
}).Debug("Session cookie found, validating JWT")
|
||||
|
||||
if !m.oidcHandler.ValidateJWT(cookie.Value) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
}).Debug("JWT validation failed for session cookie")
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
return &authResult{
|
||||
authenticated: true,
|
||||
method: "bearer_session",
|
||||
userID: m.oidcHandler.ExtractUserID(cookie.Value),
|
||||
}
|
||||
}
|
||||
|
||||
// tryAuthorizationHeader attempts authentication using the Authorization header
|
||||
func (m *Middleware) tryAuthorizationHeader(r *http.Request) *authResult {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if !m.oidcHandler.ValidateJWT(token) {
|
||||
return &authResult{}
|
||||
}
|
||||
|
||||
return &authResult{
|
||||
authenticated: true,
|
||||
method: "bearer",
|
||||
userID: m.oidcHandler.ExtractUserID(token),
|
||||
}
|
||||
}
|
||||
|
||||
// rejectRequest rejects an unauthenticated request
|
||||
func (m *Middleware) rejectRequest(w http.ResponseWriter, r *http.Request) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"path": r.URL.Path,
|
||||
}).Warn("Authentication failed")
|
||||
|
||||
if m.rejectResponse != nil {
|
||||
m.rejectResponse(w, r)
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="Restricted"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
// continueWithAuth continues the request with authenticated user info
|
||||
func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, result *authResult) {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": m.routeID,
|
||||
"auth_method": result.method,
|
||||
"user_id": result.userID,
|
||||
"path": r.URL.Path,
|
||||
}).Debug("Authentication successful")
|
||||
|
||||
// TODO: Find other means of auth logging than headers
|
||||
r.Header.Set("X-Auth-Method", result.method)
|
||||
r.Header.Set("X-Auth-User-ID", result.userID)
|
||||
|
||||
// Continue to next handler
|
||||
m.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Wrap wraps an HTTP handler with authentication middleware
|
||||
func Wrap(next http.Handler, authConfig *Config, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcHandler *oidc.Handler) http.Handler {
|
||||
if authConfig == nil {
|
||||
authConfig = &Config{}
|
||||
}
|
||||
|
||||
return &Middleware{
|
||||
next: next,
|
||||
config: authConfig,
|
||||
routeID: routeID,
|
||||
rejectResponse: rejectResponse,
|
||||
oidcHandler: oidcHandler,
|
||||
func (mw *Middleware) cleanupSessions() {
|
||||
for range time.Tick(time.Minute) {
|
||||
cutoff := time.Now().Add(-sessionExpiration)
|
||||
mw.sessionsMux.Lock()
|
||||
for id, sess := range mw.sessions {
|
||||
if sess.CreatedAt.Before(cutoff) {
|
||||
delete(mw.sessions, id)
|
||||
}
|
||||
}
|
||||
mw.sessionsMux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
202
proxy/internal/auth/oidc.go
Normal file
202
proxy/internal/auth/oidc.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const stateExpiration = 10 * time.Minute
|
||||
|
||||
// OIDCConfig holds configuration for OIDC authentication
|
||||
type OIDCConfig struct {
|
||||
OIDCProviderURL string
|
||||
OIDCClientID string
|
||||
OIDCClientSecret string
|
||||
OIDCRedirectURL string
|
||||
OIDCScopes []string
|
||||
}
|
||||
|
||||
// oidcState stores CSRF state with expiration
|
||||
type oidcState struct {
|
||||
OriginalURL string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// OIDC implements the Scheme interface for JWT/OIDC authentication
|
||||
type OIDC struct {
|
||||
verifier *oidc.IDTokenVerifier
|
||||
oauthConfig *oauth2.Config
|
||||
states map[string]*oidcState
|
||||
statesMux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(ctx context.Context, cfg OIDCConfig) (*OIDC, error) {
|
||||
if cfg.OIDCProviderURL == "" || cfg.OIDCClientID == "" {
|
||||
return nil, fmt.Errorf("OIDC provider URL and client ID are required")
|
||||
}
|
||||
|
||||
scopes := cfg.OIDCScopes
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{oidc.ScopeOpenID, "profile", "email"}
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, cfg.OIDCProviderURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC provider: %w", err)
|
||||
}
|
||||
|
||||
o := &OIDC{
|
||||
verifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
}),
|
||||
oauthConfig: &oauth2.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
ClientSecret: cfg.OIDCClientSecret,
|
||||
RedirectURL: cfg.OIDCRedirectURL,
|
||||
Scopes: scopes,
|
||||
Endpoint: provider.Endpoint(),
|
||||
},
|
||||
states: make(map[string]*oidcState),
|
||||
}
|
||||
|
||||
go o.cleanupStates()
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (*OIDC) Type() Method {
|
||||
return MethodBearer
|
||||
}
|
||||
|
||||
func (o *OIDC) Authenticate(r *http.Request) (string, bool, any) {
|
||||
// Try Authorization: Bearer <token> header
|
||||
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
|
||||
if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
|
||||
return userID, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try _auth_token query parameter (from OIDC callback redirect)
|
||||
if token := r.URL.Query().Get("_auth_token"); token != "" {
|
||||
if userID := o.validateToken(r.Context(), token); userID != "" {
|
||||
return userID, true, nil // Redirect needed to clean up URL
|
||||
}
|
||||
}
|
||||
|
||||
// If the request is not authenticated, return a redirect URL for the UI to
|
||||
// route the user through if they select OIDC login.
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
// TODO: this does not work if you are load balancing across multiple proxy servers.
|
||||
o.statesMux.Lock()
|
||||
o.states[state] = &oidcState{OriginalURL: fmt.Sprintf("https://%s%s", r.Host, r.URL), CreatedAt: time.Now()}
|
||||
o.statesMux.Unlock()
|
||||
|
||||
return "", false, o.oauthConfig.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// Middleware returns an http.Handler that handles OIDC callback and flow initiation.
|
||||
func (o *OIDC) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle OIDC callback
|
||||
if r.URL.Path == "/oauth/callback" {
|
||||
o.handleCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// validateToken validates a JWT ID token and returns the user ID (subject)
|
||||
func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
||||
if o.verifier == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
idToken, err := o.verifier.Verify(ctx, token)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return idToken.Subject
|
||||
}
|
||||
|
||||
// handleCallback processes the OIDC callback
|
||||
func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify and consume state
|
||||
o.statesMux.Lock()
|
||||
st, ok := o.states[state]
|
||||
if ok {
|
||||
delete(o.states, state)
|
||||
}
|
||||
o.statesMux.Unlock()
|
||||
|
||||
if !ok {
|
||||
http.Error(w, "Invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := o.oauthConfig.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
slog.Error("Token exchange failed", "error", err)
|
||||
http.Error(w, "Authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Prefer ID token if available
|
||||
idToken := token.AccessToken
|
||||
if id, ok := token.Extra("id_token").(string); ok && id != "" {
|
||||
idToken = id
|
||||
}
|
||||
|
||||
// Redirect back to original URL with token
|
||||
origURL, err := url.Parse(st.OriginalURL)
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
q := origURL.Query()
|
||||
q.Set("_auth_token", idToken)
|
||||
origURL.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, origURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// cleanupStates periodically removes expired states
|
||||
func (o *OIDC) cleanupStates() {
|
||||
for range time.Tick(time.Minute) {
|
||||
cutoff := time.Now().Add(-stateExpiration)
|
||||
o.statesMux.Lock()
|
||||
for k, v := range o.states {
|
||||
if v.CreatedAt.Before(cutoff) {
|
||||
delete(o.states, k)
|
||||
}
|
||||
}
|
||||
o.statesMux.Unlock()
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package oidc
|
||||
|
||||
// Config holds the global OIDC/OAuth configuration
|
||||
type Config struct {
|
||||
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"`
|
||||
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"`
|
||||
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"`
|
||||
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"`
|
||||
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"`
|
||||
|
||||
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"`
|
||||
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"`
|
||||
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"`
|
||||
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"`
|
||||
|
||||
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"`
|
||||
}
|
||||
@@ -1,285 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
// Handler manages OIDC authentication flow
|
||||
type Handler struct {
|
||||
config *Config
|
||||
stateStore *StateStore
|
||||
jwtValidator *jwt.Validator
|
||||
}
|
||||
|
||||
// NewHandler creates a new OIDC handler
|
||||
func NewHandler(config *Config, stateStore *StateStore) *Handler {
|
||||
var jwtValidator *jwt.Validator
|
||||
if config.JWTKeysLocation != "" {
|
||||
jwtValidator = jwt.NewValidator(
|
||||
config.JWTIssuer,
|
||||
config.JWTAudience,
|
||||
config.JWTKeysLocation,
|
||||
config.JWTIdpSignkeyRefreshEnabled,
|
||||
)
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
config: config,
|
||||
stateStore: stateStore,
|
||||
jwtValidator: jwtValidator,
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectToProvider initiates the OAuth/OIDC authentication flow by redirecting to the provider
|
||||
func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, routeID string) {
|
||||
// Generate random state for CSRF protection
|
||||
state, err := generateRandomString(32)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to generate OIDC state")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Store state with original URL for redirect after auth
|
||||
// Include the full URL with scheme and host
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
originalURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
|
||||
h.stateStore.Store(state, originalURL, routeID)
|
||||
|
||||
// Default scopes if not configured
|
||||
scopes := h.config.Scopes
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
authURL, err := url.Parse(h.config.ProviderURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Invalid OIDC provider URL")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Append /authorize if it doesn't exist (common OIDC endpoint)
|
||||
if !strings.HasSuffix(authURL.Path, "/authorize") && !strings.HasSuffix(authURL.Path, "/auth") {
|
||||
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.config.ClientID)
|
||||
params.Set("redirect_uri", h.config.RedirectURL)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", strings.Join(scopes, " "))
|
||||
params.Set("state", state)
|
||||
|
||||
if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID {
|
||||
params.Set("audience", h.config.JWTAudience[0])
|
||||
}
|
||||
|
||||
authURL.RawQuery = params.Encode()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeID,
|
||||
"provider_url": authURL.String(),
|
||||
"redirect_url": h.config.RedirectURL,
|
||||
"state": state,
|
||||
}).Info("Redirecting to OIDC provider for authentication")
|
||||
|
||||
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// HandleCallback creates an HTTP handler for the OIDC callback endpoint
|
||||
func (h *Handler) HandleCallback() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get authorization code and state from query parameters
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
log.Error("Missing code or state in OIDC callback")
|
||||
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify state to prevent CSRF
|
||||
oidcSt, ok := h.stateStore.Get(state)
|
||||
if !ok {
|
||||
log.Error("Invalid or expired OIDC state")
|
||||
http.Error(w, "Invalid or expired state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete state to prevent reuse
|
||||
h.stateStore.Delete(state)
|
||||
|
||||
// Exchange authorization code for token
|
||||
token, err := h.exchangeCodeForToken(code)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to exchange code for token")
|
||||
http.Error(w, "Authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the original URL to add the token as a query parameter
|
||||
origURL, err := url.Parse(oidcSt.OriginalURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to parse original URL")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Add token as query parameter so the original domain can set its own cookie
|
||||
// We use a special parameter name that the auth middleware will look for
|
||||
q := origURL.Query()
|
||||
q.Set("_auth_token", token)
|
||||
origURL.RawQuery = q.Encode()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": oidcSt.RouteID,
|
||||
"original_url": oidcSt.OriginalURL,
|
||||
"redirect_url": origURL.String(),
|
||||
"callback_host": r.Host,
|
||||
}).Info("OIDC authentication successful, redirecting with token parameter")
|
||||
|
||||
// Redirect back to original URL with token parameter
|
||||
http.Redirect(w, r, origURL.String(), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for an access token
|
||||
func (h *Handler) exchangeCodeForToken(code string) (string, error) {
|
||||
// Build token endpoint URL
|
||||
tokenURL, err := url.Parse(h.config.ProviderURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid OIDC provider URL: %w", err)
|
||||
}
|
||||
|
||||
// Auth0 uses /oauth/token, standard OIDC uses /token
|
||||
// Check if path already contains token endpoint
|
||||
if !strings.Contains(tokenURL.Path, "/token") {
|
||||
tokenURL.Path = strings.TrimSuffix(tokenURL.Path, "/") + "/oauth/token"
|
||||
}
|
||||
|
||||
// Build request body
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", h.config.RedirectURL)
|
||||
data.Set("client_id", h.config.ClientID)
|
||||
|
||||
// Only include client_secret if it's provided (not needed for public/SPA clients)
|
||||
if h.config.ClientSecret != "" {
|
||||
data.Set("client_secret", h.config.ClientSecret)
|
||||
}
|
||||
|
||||
// Make token exchange request
|
||||
resp, err := http.PostForm(tokenURL.String(), data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token exchange request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return "", fmt.Errorf("no access token in response")
|
||||
}
|
||||
|
||||
// Return the ID token if available (contains user claims), otherwise access token
|
||||
if tokenResp.IDToken != "" {
|
||||
return tokenResp.IDToken, nil
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
// ValidateJWT validates a JWT token
|
||||
func (h *Handler) ValidateJWT(tokenString string) bool {
|
||||
if h.jwtValidator == nil {
|
||||
log.Error("JWT validation failed: JWT validator not initialized")
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate the token
|
||||
ctx := context.Background()
|
||||
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("JWT validation failed")
|
||||
// Try to parse token without validation to see what's in it
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) == 3 {
|
||||
// Decode payload (middle part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err == nil {
|
||||
log.WithFields(log.Fields{
|
||||
"payload": string(payload),
|
||||
}).Debug("Token payload for debugging")
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Token is valid if parsedToken is not nil and Valid is true
|
||||
return parsedToken != nil && parsedToken.Valid
|
||||
}
|
||||
|
||||
// ExtractUserID extracts the user ID from a JWT token
|
||||
func (h *Handler) ExtractUserID(tokenString string) string {
|
||||
if h.jwtValidator == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse the token
|
||||
ctx := context.Background()
|
||||
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// parsedToken is already *jwtgo.Token from ValidateAndParse
|
||||
// Create extractor to get user auth info
|
||||
extractor := jwt.NewClaimsExtractor()
|
||||
userAuth, err := extractor.ToUserAuth(parsedToken)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("Failed to extract user ID from JWT")
|
||||
return ""
|
||||
}
|
||||
|
||||
return userAuth.UserId
|
||||
}
|
||||
|
||||
// SessionCookieName returns the configured session cookie name or default
|
||||
func (h *Handler) SessionCookieName() string {
|
||||
if h.config.SessionCookieName != "" {
|
||||
return h.config.SessionCookieName
|
||||
}
|
||||
return "auth_session"
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import "time"
|
||||
|
||||
// State represents stored OIDC state information for CSRF protection
|
||||
type State struct {
|
||||
OriginalURL string
|
||||
CreatedAt time.Time
|
||||
RouteID string
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// StateExpiration is how long OIDC state tokens are valid
|
||||
StateExpiration = 10 * time.Minute
|
||||
)
|
||||
|
||||
// StateStore manages OIDC state tokens for CSRF protection
|
||||
type StateStore struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]*State
|
||||
}
|
||||
|
||||
// NewStateStore creates a new OIDC state store
|
||||
func NewStateStore() *StateStore {
|
||||
return &StateStore{
|
||||
states: make(map[string]*State),
|
||||
}
|
||||
}
|
||||
|
||||
// Store saves a state token with associated metadata
|
||||
func (s *StateStore) Store(stateToken, originalURL, routeID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.states[stateToken] = &State{
|
||||
OriginalURL: originalURL,
|
||||
CreatedAt: time.Now(),
|
||||
RouteID: routeID,
|
||||
}
|
||||
|
||||
s.cleanup()
|
||||
}
|
||||
|
||||
// Get retrieves a state by token
|
||||
func (s *StateStore) Get(stateToken string) (*State, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
st, ok := s.states[stateToken]
|
||||
return st, ok
|
||||
}
|
||||
|
||||
// Delete removes a state token
|
||||
func (s *StateStore) Delete(stateToken string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.states, stateToken)
|
||||
}
|
||||
|
||||
// cleanup removes expired state tokens (must be called with lock held)
|
||||
func (s *StateStore) cleanup() {
|
||||
cutoff := time.Now().Add(-StateExpiration)
|
||||
for k, v := range s.states {
|
||||
if v.CreatedAt.Before(cutoff) {
|
||||
delete(s.states, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// generateRandomString generates a cryptographically secure random string of the specified length
|
||||
func generateRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
|
||||
}
|
||||
45
proxy/internal/auth/pin.go
Normal file
45
proxy/internal/auth/pin.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
userId = "pin-user"
|
||||
formId = "pin"
|
||||
)
|
||||
|
||||
type Pin struct {
|
||||
pin string
|
||||
}
|
||||
|
||||
func NewPin(pin string) Pin {
|
||||
return Pin{
|
||||
pin: pin,
|
||||
}
|
||||
}
|
||||
|
||||
func (Pin) Type() Method {
|
||||
return MethodPIN
|
||||
}
|
||||
|
||||
// Authenticate attempts to authenticate the request using a form
|
||||
// value passed in the request.
|
||||
// If authentication fails, the required HTTP form ID is returned
|
||||
// so that it can be injected into a request from the UI so that
|
||||
// authentication may be successful.
|
||||
func (p Pin) Authenticate(r *http.Request) (string, bool, any) {
|
||||
pin := r.FormValue(formId)
|
||||
|
||||
// Compare the passed pin with the expected pin.
|
||||
if subtle.ConstantTimeCompare([]byte(pin), []byte(p.pin)) == 1 {
|
||||
return userId, false, nil
|
||||
}
|
||||
|
||||
return "", false, formId
|
||||
}
|
||||
|
||||
func (p Pin) Middleware(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
Reference in New Issue
Block a user