mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-30 15:36:36 +00:00
Use session store for state
This commit is contained in:
@@ -4,24 +4,25 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CacheExpiration = time.Minute * 2
|
CacheExpiration = time.Minute * 2
|
||||||
CleanupInterval = time.Minute * 5
|
CleanupInterval = time.Minute * 5
|
||||||
|
oidcStateKey = "OIDCSTATE"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OIDC struct {
|
type OIDC struct {
|
||||||
oAuth2Config *oauth2.Config
|
oAuth2Config *oauth2.Config
|
||||||
oidcTokenVerifier *oidc.IDTokenVerifier
|
oidcTokenVerifier *oidc.IDTokenVerifier
|
||||||
stateStore *cache.Cache
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCConfig struct {
|
type OIDCConfig struct {
|
||||||
@@ -33,18 +34,63 @@ func (c *OIDCConfig) New() *OIDC {
|
|||||||
return &OIDC{
|
return &OIDC{
|
||||||
oAuth2Config: c.OAuth2Config,
|
oAuth2Config: c.OAuth2Config,
|
||||||
oidcTokenVerifier: c.OIDCTokenVerifier,
|
oidcTokenVerifier: c.OIDCTokenVerifier,
|
||||||
stateStore: cache.New(CacheExpiration, CleanupInterval),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// storeOIDCState stores the OIDC state and redirect URL in the session
|
||||||
|
func storeOIDCState(w http.ResponseWriter, r *http.Request, state string, redirectURL string) error {
|
||||||
|
session, err := GetSession(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store state data directly as a concatenated string: state + "|" + redirectURL
|
||||||
|
stateValue := state + "|" + redirectURL
|
||||||
|
session.Values[oidcStateKey] = stateValue
|
||||||
|
session.Options.MaxAge = int(CacheExpiration.Seconds())
|
||||||
|
|
||||||
|
return sessionStore.Save(r, w, session)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOIDCState retrieves the redirect URL for the given state from the session
|
||||||
|
func getOIDCState(r *http.Request, state string) (string, bool) {
|
||||||
|
session, err := GetSession(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting session for OIDC state: %v", err)
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
stateData, exists := session.Values[oidcStateKey]
|
||||||
|
if !exists {
|
||||||
|
log.Printf("No OIDC state data found in session")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
stateValue, ok := stateData.(string)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("Invalid OIDC state data format in session")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse state data: state + "|" + redirectURL
|
||||||
|
expectedPrefix := state + "|"
|
||||||
|
if !strings.HasPrefix(stateValue, expectedPrefix) {
|
||||||
|
log.Printf("OIDC state '%s' not found in session", state)
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL := stateValue[len(expectedPrefix):]
|
||||||
|
return redirectURL, true
|
||||||
|
}
|
||||||
|
|
||||||
func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
s, found := h.stateStore.Get(state)
|
url, found := getOIDCState(r, state)
|
||||||
if !found {
|
if !found {
|
||||||
|
log.Printf("OIDC HandleCallback: unknown state '%s'", state)
|
||||||
http.Error(w, "unknown state", http.StatusBadRequest)
|
http.Error(w, "unknown state", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
url := s.(string)
|
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||||
@@ -123,7 +169,15 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
state := hex.EncodeToString(seed)
|
state := hex.EncodeToString(seed)
|
||||||
h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
|
|
||||||
|
log.Printf("OIDC Authenticated: storing state '%s' for redirect to '%s'", state, r.RequestURI)
|
||||||
|
err = storeOIDCState(w, r, state, r.RequestURI)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("OIDC Authenticated: failed to store state: %v", err)
|
||||||
|
http.Error(w, "Failed to store OIDC state", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound)
|
http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package web
|
package web
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func TestFindUserNameInClaims(t *testing.T) {
|
func TestFindUserNameInClaims(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
@@ -47,3 +50,66 @@ func TestFindUserNameInClaims(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDCStateManagement(t *testing.T) {
|
||||||
|
// Initialize session store for testing
|
||||||
|
sessionKey := []byte("testsessionkeytestsessionkey1234") // 32 bytes
|
||||||
|
encryptionKey := []byte("testencryptionkeytestencrypt1234") // 32 bytes
|
||||||
|
InitStore(sessionKey, encryptionKey, "cookie", 8192)
|
||||||
|
|
||||||
|
// Test storing and retrieving OIDC state
|
||||||
|
t.Run("StoreAndRetrieveState", func(t *testing.T) {
|
||||||
|
// Create test request and response
|
||||||
|
req := httptest.NewRequest("GET", "/connect", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
state := "test-state-12345"
|
||||||
|
redirectURL := "/original-request"
|
||||||
|
|
||||||
|
// Store state
|
||||||
|
err := storeOIDCState(w, req, state, redirectURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to store OIDC state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new request with the same session cookie
|
||||||
|
callbackReq := httptest.NewRequest("GET", "/callback", nil)
|
||||||
|
|
||||||
|
// Copy session cookie from response to new request
|
||||||
|
for _, cookie := range w.Result().Cookies() {
|
||||||
|
callbackReq.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve state
|
||||||
|
retrievedURL, found := getOIDCState(callbackReq, state)
|
||||||
|
if !found {
|
||||||
|
t.Fatal("Expected to find OIDC state, but it was not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrievedURL != redirectURL {
|
||||||
|
t.Fatalf("Expected redirect URL '%s', got '%s'", redirectURL, retrievedURL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("StateNotFound", func(t *testing.T) {
|
||||||
|
// Create test request
|
||||||
|
req := httptest.NewRequest("GET", "/callback", nil)
|
||||||
|
|
||||||
|
// Try to retrieve non-existent state
|
||||||
|
_, found := getOIDCState(req, "non-existent-state")
|
||||||
|
if found {
|
||||||
|
t.Fatal("Expected state not to be found, but it was found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EmptySession", func(t *testing.T) {
|
||||||
|
// Create fresh request with no session
|
||||||
|
req := httptest.NewRequest("GET", "/callback", nil)
|
||||||
|
|
||||||
|
// Try to retrieve state from empty session
|
||||||
|
_, found := getOIDCState(req, "any-state")
|
||||||
|
if found {
|
||||||
|
t.Fatal("Expected state not to be found in empty session, but it was found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user