mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-29 07:06:34 +00:00
Fix checking host from list
This commit is contained in:
@@ -24,7 +24,7 @@ type RedirectFlags struct {
|
|||||||
|
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
// The connection-id (RDG-ConnID) as reported by the client
|
// The connection-id (RDG-ConnID) as reported by the client
|
||||||
ConnId string
|
ConnId string
|
||||||
// The underlying incoming transport being either websocket or legacy http
|
// The underlying incoming transport being either websocket or legacy http
|
||||||
// in case of websocket TransportOut will equal TransportIn
|
// in case of websocket TransportOut will equal TransportIn
|
||||||
TransportIn transport.Transport
|
TransportIn transport.Transport
|
||||||
@@ -32,9 +32,11 @@ type SessionInfo struct {
|
|||||||
// in case of websocket TransportOut will equal TransportOut
|
// in case of websocket TransportOut will equal TransportOut
|
||||||
TransportOut transport.Transport
|
TransportOut transport.Transport
|
||||||
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
|
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
|
||||||
RemoteServer string
|
RemoteServer string
|
||||||
// The obtained client ip address
|
// The obtained client ip address
|
||||||
ClientIp string
|
ClientIp string
|
||||||
|
// User
|
||||||
|
UserName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// readMessage parses and defragments a packet from a Transport. It returns
|
// readMessage parses and defragments a packet from a Transport. It returns
|
||||||
|
|||||||
@@ -22,10 +22,13 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
|
|||||||
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
|
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
|
||||||
case "roundrobin", "unsigned":
|
case "roundrobin", "unsigned":
|
||||||
log.Printf("Checking host")
|
log.Printf("Checking host")
|
||||||
username := ctx.Value("preferred_username").(string)
|
s := getSessionInfo(ctx)
|
||||||
|
if s == nil {
|
||||||
|
return false, errors.New("no valid session info found in context")
|
||||||
|
}
|
||||||
for _, h := range Hosts {
|
for _, h := range Hosts {
|
||||||
if username != "" {
|
if s.UserName != "" {
|
||||||
h = strings.Replace(h, "{{ preferred_username }}", username, 1)
|
h = strings.Replace(h, "{{ preferred_username }}", s.UserName, 1)
|
||||||
}
|
}
|
||||||
if h == host {
|
if h == host {
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|||||||
@@ -95,19 +95,18 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validate the access token
|
// validate the access token
|
||||||
if custom.AccessToken != "EMPTY" {
|
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
|
||||||
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
|
user, err := OIDCProvider.UserInfo(ctx, tokenSource)
|
||||||
_, err = OIDCProvider.UserInfo(ctx, tokenSource)
|
if err != nil {
|
||||||
if err != nil {
|
log.Printf("Cannot get user info for access token: %s", err)
|
||||||
log.Printf("Cannot get user info for access token: %s", err)
|
return false, err
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s := getSessionInfo(ctx)
|
s := getSessionInfo(ctx)
|
||||||
|
|
||||||
s.RemoteServer = custom.RemoteServer
|
s.RemoteServer = custom.RemoteServer
|
||||||
s.ClientIp = custom.ClientIP
|
s.ClientIp = custom.ClientIP
|
||||||
|
s.UserName = user.Subject
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user