Enable signed hosts provied in query parameters

This commit is contained in:
Bolke de Bruin
2022-08-17 19:12:28 +02:00
parent 8bc3e25f83
commit cb8b269478
5 changed files with 132 additions and 15 deletions

View File

@@ -27,6 +27,7 @@ const (
type TokenGeneratorFunc func(context.Context, string, string) (string, error) type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error) type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type QueryInfoFunc func(context.Context, string, string) (string, error)
type Config struct { type Config struct {
SessionKey []byte SessionKey []byte
@@ -34,6 +35,8 @@ type Config struct {
SessionStore string SessionStore string
PAATokenGenerator TokenGeneratorFunc PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc UserTokenGenerator UserTokenGeneratorFunc
QueryInfo QueryInfoFunc
QueryTokenIssuer string
EnableUserToken bool EnableUserToken bool
OAuth2Config *oauth2.Config OAuth2Config *oauth2.Config
store sessions.Store store sessions.Store
@@ -159,39 +162,53 @@ func (c *Config) selectRandomHost() string {
return host return host
} }
func (c *Config) getHost(u *url.URL) (string, error) { func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) {
var host string
switch c.HostSelection { switch c.HostSelection {
case "roundrobin": case "roundrobin":
host = c.selectRandomHost() return c.selectRandomHost(), nil
case "signed": case "signed":
case "unsigned":
hosts, ok := u.Query()["host"] hosts, ok := u.Query()["host"]
if !ok { if !ok {
return "", errors.New("invalid query parameter") return "", errors.New("invalid query parameter")
} }
host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer)
if err != nil {
return "", err
}
found := false found := false
for _, check := range c.Hosts { for _, check := range c.Hosts {
if check == hosts[0] { if check == host {
host = hosts[0]
found = true found = true
break break
} }
} }
if !found { if !found {
log.Printf("Invalid host %s specified in client request", hosts[0]) log.Printf("Invalid host %s specified in token", hosts[0])
return "", errors.New("invalid host specified in query parameter") return "", errors.New("invalid host specified in query token")
} }
return host, nil
case "unsigned":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
for _, check := range c.Hosts {
if check == hosts[0] {
return hosts[0], nil
}
}
// not found
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
case "any": case "any":
hosts, ok := u.Query()["host"] hosts, ok := u.Query()["host"]
if !ok { if !ok {
return "", errors.New("invalid query parameter") return "", errors.New("invalid query parameter")
} }
host = hosts[0] return hosts[0], nil
default: default:
host = c.selectRandomHost() return c.selectRandomHost(), nil
} }
return host, nil
} }
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
@@ -205,7 +222,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
} }
// determine host to connect to // determine host to connect to
host, err := c.getHost(r.URL) host, err := c.getHost(ctx, r.URL)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return

View File

@@ -1,12 +1,15 @@
package api package api
import ( import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"net/url" "net/url"
"testing" "testing"
) )
var ( var (
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"} hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
key = []byte("thisisasessionkeyreplacethisjetzt")
) )
func contains(needle string, haystack []string) bool { func contains(needle string, haystack []string) bool {
@@ -19,6 +22,7 @@ func contains(needle string, haystack []string) bool {
} }
func TestGetHost(t *testing.T) { func TestGetHost(t *testing.T) {
ctx := context.Background()
c := Config{ c := Config{
HostSelection: "roundrobin", HostSelection: "roundrobin",
Hosts: hosts, Hosts: hosts,
@@ -28,7 +32,7 @@ func TestGetHost(t *testing.T) {
} }
vals := u.Query() vals := u.Query()
host, err := c.getHost(u) host, err := c.getHost(ctx, u)
if err != nil { if err != nil {
t.Fatalf("#{err}") t.Fatalf("#{err}")
} }
@@ -40,14 +44,14 @@ func TestGetHost(t *testing.T) {
c.HostSelection = "unsigned" c.HostSelection = "unsigned"
vals.Set("host", "in.valid.host") vals.Set("host", "in.valid.host")
u.RawQuery = vals.Encode() u.RawQuery = vals.Encode()
host, err = c.getHost(u) host, err = c.getHost(ctx, u)
if err == nil { if err == nil {
t.Fatalf("Accepted host %s is not in hosts list", host) t.Fatalf("Accepted host %s is not in hosts list", host)
} }
vals.Set("host", hosts[0]) vals.Set("host", hosts[0])
u.RawQuery = vals.Encode() u.RawQuery = vals.Encode()
host, err = c.getHost(u) host, err = c.getHost(ctx, u)
if err != nil { if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
} }
@@ -55,4 +59,35 @@ func TestGetHost(t *testing.T) {
t.Fatalf("host %s is not equal to input %s", host, hosts[0]) t.Fatalf("host %s is not equal to input %s", host, hosts[0])
} }
// check any
c.HostSelection = "any"
test := "bla.bla.com"
vals.Set("host", test)
u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u)
if err != nil {
t.Fatalf("%s is not accepted", host)
}
if test != host {
t.Fatalf("Returned host %s is not equal to input host %s", host, test)
}
// check signed
c.HostSelection = "signed"
c.QueryInfo = security.QueryInfo
issuer := "rdpgwtest"
security.QuerySigningKey = key
queryToken, err := security.GenerateQueryToken(ctx, hosts[0], issuer)
if err != nil {
t.Fatalf("cannot generate token")
}
vals.Set("host", queryToken)
u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u)
if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
}
if host != hosts[0] {
t.Fatalf("%s does not equal %s", host, hosts[0])
}
} }

View File

@@ -58,6 +58,8 @@ type SecurityConfig struct {
PAATokenSigningKey string `koanf:"paatokensigningkey"` PAATokenSigningKey string `koanf:"paatokensigningkey"`
UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"` UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"`
UserTokenSigningKey string `koanf:"usertokensigningkey"` UserTokenSigningKey string `koanf:"usertokensigningkey"`
QueryTokenSigningKey string `koanf:"querytokensigningkey"`
QueryTokenIssuer string `koanf:"querytokenissuer"`
VerifyClientIp bool `koanf:"verifyclientip"` VerifyClientIp bool `koanf:"verifyclientip"`
EnableUserToken bool `koanf:"enableusertoken"` EnableUserToken bool `koanf:"enableusertoken"`
} }
@@ -176,6 +178,10 @@ func Load(configFile string) Configuration {
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random") log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
} }
if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
}
return Conf return Conf
} }

View File

@@ -40,6 +40,7 @@ func main() {
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey) security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey) security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey) security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey)
// set oidc config // set oidc config
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
@@ -74,6 +75,8 @@ func main() {
OIDCTokenVerifier: verifier, OIDCTokenVerifier: verifier,
PAATokenGenerator: security.GeneratePAAToken, PAATokenGenerator: security.GeneratePAAToken,
UserTokenGenerator: security.GenerateUserToken, UserTokenGenerator: security.GenerateUserToken,
QueryInfo: security.QueryInfo,
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
EnableUserToken: conf.Security.EnableUserToken, EnableUserToken: conf.Security.EnableUserToken,
SessionKey: []byte(conf.Server.SessionKey), SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),

View File

@@ -19,6 +19,7 @@ var (
EncryptionKey []byte EncryptionKey []byte
UserSigningKey []byte UserSigningKey []byte
UserEncryptionKey []byte UserEncryptionKey []byte
QuerySigningKey []byte
OIDCProvider *oidc.Provider OIDCProvider *oidc.Provider
Oauth2Config oauth2.Config Oauth2Config oauth2.Config
) )
@@ -221,6 +222,61 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
return standard, nil return standard, nil
} }
func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) {
standard := jwt.Claims{}
token, err := jwt.ParseSigned(tokenString)
if err != nil {
log.Printf("Cannot get token %s", err)
return "", errors.New("cannot get token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return "", errors.New("signature validation failure")
}
err = token.Claims(QuerySigningKey, &standard)
if err = token.Claims(QuerySigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return "", errors.New("cannot verify signature")
}
// go-jose doesnt verify the expiry
err = standard.Validate(jwt.Expected{
Issuer: issuer,
Time: time.Now(),
})
if err != nil {
log.Printf("token validation failed due to %s", err)
return "", fmt.Errorf("token validation failed due to %s", err)
}
return standard.Subject, nil
}
// GenerateQueryToken this is a helper function for testing
func GenerateQueryToken(ctx context.Context, query string, issuer string) (string, error) {
if len(QuerySigningKey) < 32 {
return "", errors.New("query token encryption key not long enough or not specified")
}
claims := jwt.Claims{
Subject: query,
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
Issuer: issuer,
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: QuerySigningKey},
(&jose.SignerOptions{}).WithBase64(true))
if err != nil {
log.Printf("Cannot encrypt user token due to %s", err)
return "", err
}
token, err := jwt.Signed(sig).Claims(claims).CompactSerialize()
return token, err
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo { func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok { if !ok {