Fix checking host from list

This commit is contained in:
Bolke de Bruin
2022-08-26 11:59:46 +02:00
parent 19e9e3269d
commit 184ff320b8
3 changed files with 18 additions and 14 deletions

View File

@@ -24,7 +24,7 @@ type RedirectFlags struct {
type SessionInfo struct { type SessionInfo struct {
// The connection-id (RDG-ConnID) as reported by the client // The connection-id (RDG-ConnID) as reported by the client
ConnId string ConnId string
// The underlying incoming transport being either websocket or legacy http // The underlying incoming transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportIn // in case of websocket TransportOut will equal TransportIn
TransportIn transport.Transport TransportIn transport.Transport
@@ -32,9 +32,11 @@ type SessionInfo struct {
// in case of websocket TransportOut will equal TransportOut // in case of websocket TransportOut will equal TransportOut
TransportOut transport.Transport TransportOut transport.Transport
// The remote desktop server (rdp, vnc etc) the clients intends to connect to // The remote desktop server (rdp, vnc etc) the clients intends to connect to
RemoteServer string RemoteServer string
// The obtained client ip address // The obtained client ip address
ClientIp string ClientIp string
// User
UserName string
} }
// readMessage parses and defragments a packet from a Transport. It returns // readMessage parses and defragments a packet from a Transport. It returns

View File

@@ -22,10 +22,13 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
return false, errors.New("cannot verify host in 'signed' mode as token data is missing") return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
case "roundrobin", "unsigned": case "roundrobin", "unsigned":
log.Printf("Checking host") log.Printf("Checking host")
username := ctx.Value("preferred_username").(string) s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
for _, h := range Hosts { for _, h := range Hosts {
if username != "" { if s.UserName != "" {
h = strings.Replace(h, "{{ preferred_username }}", username, 1) h = strings.Replace(h, "{{ preferred_username }}", s.UserName, 1)
} }
if h == host { if h == host {
return true, nil return true, nil

View File

@@ -95,19 +95,18 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
} }
// validate the access token // validate the access token
if custom.AccessToken != "EMPTY" { tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) user, err := OIDCProvider.UserInfo(ctx, tokenSource)
_, err = OIDCProvider.UserInfo(ctx, tokenSource) if err != nil {
if err != nil { log.Printf("Cannot get user info for access token: %s", err)
log.Printf("Cannot get user info for access token: %s", err) return false, err
return false, err
}
} }
s := getSessionInfo(ctx) s := getSessionInfo(ctx)
s.RemoteServer = custom.RemoteServer s.RemoteServer = custom.RemoteServer
s.ClientIp = custom.ClientIP s.ClientIp = custom.ClientIP
s.UserName = user.Subject
return true, nil return true, nil
} }