diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go index cdc8c90..37f79ac 100644 --- a/cmd/rdpgw/web/oidc.go +++ b/cmd/rdpgw/web/oidc.go @@ -4,24 +4,25 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "log" "net/http" + "strings" "time" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" "golang.org/x/oauth2" ) const ( CacheExpiration = time.Minute * 2 CleanupInterval = time.Minute * 5 + oidcStateKey = "OIDCSTATE" ) type OIDC struct { oAuth2Config *oauth2.Config oidcTokenVerifier *oidc.IDTokenVerifier - stateStore *cache.Cache } type OIDCConfig struct { @@ -33,18 +34,63 @@ func (c *OIDCConfig) New() *OIDC { return &OIDC{ oAuth2Config: c.OAuth2Config, oidcTokenVerifier: c.OIDCTokenVerifier, - stateStore: cache.New(CacheExpiration, CleanupInterval), } } +// storeOIDCState stores the OIDC state and redirect URL in the session +func storeOIDCState(w http.ResponseWriter, r *http.Request, state string, redirectURL string) error { + session, err := GetSession(r) + if err != nil { + return err + } + + // Store state data directly as a concatenated string: state + "|" + redirectURL + stateValue := state + "|" + redirectURL + session.Values[oidcStateKey] = stateValue + session.Options.MaxAge = int(CacheExpiration.Seconds()) + + return sessionStore.Save(r, w, session) +} + +// getOIDCState retrieves the redirect URL for the given state from the session +func getOIDCState(r *http.Request, state string) (string, bool) { + session, err := GetSession(r) + if err != nil { + log.Printf("Error getting session for OIDC state: %v", err) + return "", false + } + + stateData, exists := session.Values[oidcStateKey] + if !exists { + log.Printf("No OIDC state data found in session") + return "", false + } + + stateValue, ok := stateData.(string) + if !ok { + log.Printf("Invalid OIDC state data format in session") + return "", false + } + + // Parse state data: state + "|" + redirectURL + expectedPrefix := state + "|" + if !strings.HasPrefix(stateValue, expectedPrefix) { + log.Printf("OIDC state '%s' not found in session", state) + return "", false + } + + redirectURL := stateValue[len(expectedPrefix):] + return redirectURL, true +} + func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { state := r.URL.Query().Get("state") - s, found := h.stateStore.Get(state) + url, found := getOIDCState(r, state) if !found { + log.Printf("OIDC HandleCallback: unknown state '%s'", state) http.Error(w, "unknown state", http.StatusBadRequest) return } - url := s.(string) ctx := r.Context() oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code")) @@ -123,7 +169,15 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler { return } state := hex.EncodeToString(seed) - h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration) + + log.Printf("OIDC Authenticated: storing state '%s' for redirect to '%s'", state, r.RequestURI) + err = storeOIDCState(w, r, state, r.RequestURI) + if err != nil { + log.Printf("OIDC Authenticated: failed to store state: %v", err) + http.Error(w, "Failed to store OIDC state", http.StatusInternalServerError) + return + } + http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound) return } diff --git a/cmd/rdpgw/web/oidc_test.go b/cmd/rdpgw/web/oidc_test.go index 37eb908..4e8d787 100644 --- a/cmd/rdpgw/web/oidc_test.go +++ b/cmd/rdpgw/web/oidc_test.go @@ -1,6 +1,9 @@ package web -import "testing" +import ( + "net/http/httptest" + "testing" +) func TestFindUserNameInClaims(t *testing.T) { cases := []struct { @@ -47,3 +50,66 @@ func TestFindUserNameInClaims(t *testing.T) { }) } } + +func TestOIDCStateManagement(t *testing.T) { + // Initialize session store for testing + sessionKey := []byte("testsessionkeytestsessionkey1234") // 32 bytes + encryptionKey := []byte("testencryptionkeytestencrypt1234") // 32 bytes + InitStore(sessionKey, encryptionKey, "cookie", 8192) + + // Test storing and retrieving OIDC state + t.Run("StoreAndRetrieveState", func(t *testing.T) { + // Create test request and response + req := httptest.NewRequest("GET", "/connect", nil) + w := httptest.NewRecorder() + + state := "test-state-12345" + redirectURL := "/original-request" + + // Store state + err := storeOIDCState(w, req, state, redirectURL) + if err != nil { + t.Fatalf("Failed to store OIDC state: %v", err) + } + + // Create a new request with the same session cookie + callbackReq := httptest.NewRequest("GET", "/callback", nil) + + // Copy session cookie from response to new request + for _, cookie := range w.Result().Cookies() { + callbackReq.AddCookie(cookie) + } + + // Retrieve state + retrievedURL, found := getOIDCState(callbackReq, state) + if !found { + t.Fatal("Expected to find OIDC state, but it was not found") + } + + if retrievedURL != redirectURL { + t.Fatalf("Expected redirect URL '%s', got '%s'", redirectURL, retrievedURL) + } + }) + + t.Run("StateNotFound", func(t *testing.T) { + // Create test request + req := httptest.NewRequest("GET", "/callback", nil) + + // Try to retrieve non-existent state + _, found := getOIDCState(req, "non-existent-state") + if found { + t.Fatal("Expected state not to be found, but it was found") + } + }) + + t.Run("EmptySession", func(t *testing.T) { + // Create fresh request with no session + req := httptest.NewRequest("GET", "/callback", nil) + + // Try to retrieve state from empty session + _, found := getOIDCState(req, "any-state") + if found { + t.Fatal("Expected state not to be found in empty session, but it was found") + } + }) +}