Check hostname specified by client against the token

This commit is contained in:
Bolke de Bruin
2020-07-25 19:37:33 +02:00
parent 39c73fc8fc
commit 5f3c7d07e2
5 changed files with 53 additions and 32 deletions

View File

@@ -1,6 +1,7 @@
package security
import (
"context"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/protocol"
@@ -17,7 +18,7 @@ type customClaims struct {
jwt.StandardClaims
}
func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -30,7 +31,9 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
return false, err
}
if _, ok := token.Claims.(*customClaims); ok && token.Valid {
if c, ok := token.Claims.(*customClaims); ok && token.Valid {
s := getSessionInfo(ctx)
s.RemoteServer = c.RemoteServer
return true, nil
}
@@ -38,7 +41,21 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
return false, err
}
func GeneratePAAToken(username string, server string) (string, error) {
func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
if s.RemoteServer != host {
log.Printf("Client host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
return true, nil
}
func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
if len(SigningKey) < 32 {
return "", errors.New("token signing key not long enough or not specified")
}
@@ -67,4 +84,13 @@ func GeneratePAAToken(username string, server string) (string, error) {
} else {
return ss, nil
}
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok {
log.Printf("cannot get session info from context")
return nil
}
return s
}