mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-30 07:26:35 +00:00
Enable signed hosts provied in query parameters
This commit is contained in:
@@ -27,6 +27,7 @@ const (
|
|||||||
|
|
||||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||||
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
||||||
|
type QueryInfoFunc func(context.Context, string, string) (string, error)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
SessionKey []byte
|
SessionKey []byte
|
||||||
@@ -34,6 +35,8 @@ type Config struct {
|
|||||||
SessionStore string
|
SessionStore string
|
||||||
PAATokenGenerator TokenGeneratorFunc
|
PAATokenGenerator TokenGeneratorFunc
|
||||||
UserTokenGenerator UserTokenGeneratorFunc
|
UserTokenGenerator UserTokenGeneratorFunc
|
||||||
|
QueryInfo QueryInfoFunc
|
||||||
|
QueryTokenIssuer string
|
||||||
EnableUserToken bool
|
EnableUserToken bool
|
||||||
OAuth2Config *oauth2.Config
|
OAuth2Config *oauth2.Config
|
||||||
store sessions.Store
|
store sessions.Store
|
||||||
@@ -159,39 +162,53 @@ func (c *Config) selectRandomHost() string {
|
|||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) getHost(u *url.URL) (string, error) {
|
func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||||
var host string
|
|
||||||
switch c.HostSelection {
|
switch c.HostSelection {
|
||||||
case "roundrobin":
|
case "roundrobin":
|
||||||
host = c.selectRandomHost()
|
return c.selectRandomHost(), nil
|
||||||
case "signed":
|
case "signed":
|
||||||
case "unsigned":
|
|
||||||
hosts, ok := u.Query()["host"]
|
hosts, ok := u.Query()["host"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("invalid query parameter")
|
return "", errors.New("invalid query parameter")
|
||||||
}
|
}
|
||||||
|
host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
found := false
|
found := false
|
||||||
for _, check := range c.Hosts {
|
for _, check := range c.Hosts {
|
||||||
if check == hosts[0] {
|
if check == host {
|
||||||
host = hosts[0]
|
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
log.Printf("Invalid host %s specified in token", hosts[0])
|
||||||
return "", errors.New("invalid host specified in query parameter")
|
return "", errors.New("invalid host specified in query token")
|
||||||
}
|
}
|
||||||
|
return host, nil
|
||||||
|
case "unsigned":
|
||||||
|
hosts, ok := u.Query()["host"]
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("invalid query parameter")
|
||||||
|
}
|
||||||
|
for _, check := range c.Hosts {
|
||||||
|
if check == hosts[0] {
|
||||||
|
return hosts[0], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// not found
|
||||||
|
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||||
|
return "", errors.New("invalid host specified in query parameter")
|
||||||
case "any":
|
case "any":
|
||||||
hosts, ok := u.Query()["host"]
|
hosts, ok := u.Query()["host"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("invalid query parameter")
|
return "", errors.New("invalid query parameter")
|
||||||
}
|
}
|
||||||
host = hosts[0]
|
return hosts[0], nil
|
||||||
default:
|
default:
|
||||||
host = c.selectRandomHost()
|
return c.selectRandomHost(), nil
|
||||||
}
|
}
|
||||||
return host, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -205,7 +222,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// determine host to connect to
|
// determine host to connect to
|
||||||
host, err := c.getHost(r.URL)
|
host, err := c.getHost(ctx, r.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
|
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
|
||||||
|
key = []byte("thisisasessionkeyreplacethisjetzt")
|
||||||
)
|
)
|
||||||
|
|
||||||
func contains(needle string, haystack []string) bool {
|
func contains(needle string, haystack []string) bool {
|
||||||
@@ -19,6 +22,7 @@ func contains(needle string, haystack []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetHost(t *testing.T) {
|
func TestGetHost(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
c := Config{
|
c := Config{
|
||||||
HostSelection: "roundrobin",
|
HostSelection: "roundrobin",
|
||||||
Hosts: hosts,
|
Hosts: hosts,
|
||||||
@@ -28,7 +32,7 @@ func TestGetHost(t *testing.T) {
|
|||||||
}
|
}
|
||||||
vals := u.Query()
|
vals := u.Query()
|
||||||
|
|
||||||
host, err := c.getHost(u)
|
host, err := c.getHost(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("#{err}")
|
t.Fatalf("#{err}")
|
||||||
}
|
}
|
||||||
@@ -40,14 +44,14 @@ func TestGetHost(t *testing.T) {
|
|||||||
c.HostSelection = "unsigned"
|
c.HostSelection = "unsigned"
|
||||||
vals.Set("host", "in.valid.host")
|
vals.Set("host", "in.valid.host")
|
||||||
u.RawQuery = vals.Encode()
|
u.RawQuery = vals.Encode()
|
||||||
host, err = c.getHost(u)
|
host, err = c.getHost(ctx, u)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Accepted host %s is not in hosts list", host)
|
t.Fatalf("Accepted host %s is not in hosts list", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
vals.Set("host", hosts[0])
|
vals.Set("host", hosts[0])
|
||||||
u.RawQuery = vals.Encode()
|
u.RawQuery = vals.Encode()
|
||||||
host, err = c.getHost(u)
|
host, err = c.getHost(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||||
}
|
}
|
||||||
@@ -55,4 +59,35 @@ func TestGetHost(t *testing.T) {
|
|||||||
t.Fatalf("host %s is not equal to input %s", host, hosts[0])
|
t.Fatalf("host %s is not equal to input %s", host, hosts[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check any
|
||||||
|
c.HostSelection = "any"
|
||||||
|
test := "bla.bla.com"
|
||||||
|
vals.Set("host", test)
|
||||||
|
u.RawQuery = vals.Encode()
|
||||||
|
host, err = c.getHost(ctx, u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%s is not accepted", host)
|
||||||
|
}
|
||||||
|
if test != host {
|
||||||
|
t.Fatalf("Returned host %s is not equal to input host %s", host, test)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check signed
|
||||||
|
c.HostSelection = "signed"
|
||||||
|
c.QueryInfo = security.QueryInfo
|
||||||
|
issuer := "rdpgwtest"
|
||||||
|
security.QuerySigningKey = key
|
||||||
|
queryToken, err := security.GenerateQueryToken(ctx, hosts[0], issuer)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cannot generate token")
|
||||||
|
}
|
||||||
|
vals.Set("host", queryToken)
|
||||||
|
u.RawQuery = vals.Encode()
|
||||||
|
host, err = c.getHost(ctx, u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||||
|
}
|
||||||
|
if host != hosts[0] {
|
||||||
|
t.Fatalf("%s does not equal %s", host, hosts[0])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ type SecurityConfig struct {
|
|||||||
PAATokenSigningKey string `koanf:"paatokensigningkey"`
|
PAATokenSigningKey string `koanf:"paatokensigningkey"`
|
||||||
UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"`
|
UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"`
|
||||||
UserTokenSigningKey string `koanf:"usertokensigningkey"`
|
UserTokenSigningKey string `koanf:"usertokensigningkey"`
|
||||||
|
QueryTokenSigningKey string `koanf:"querytokensigningkey"`
|
||||||
|
QueryTokenIssuer string `koanf:"querytokenissuer"`
|
||||||
VerifyClientIp bool `koanf:"verifyclientip"`
|
VerifyClientIp bool `koanf:"verifyclientip"`
|
||||||
EnableUserToken bool `koanf:"enableusertoken"`
|
EnableUserToken bool `koanf:"enableusertoken"`
|
||||||
}
|
}
|
||||||
@@ -176,6 +178,10 @@ func Load(configFile string) Configuration {
|
|||||||
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
|
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 {
|
||||||
|
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
|
||||||
|
}
|
||||||
|
|
||||||
return Conf
|
return Conf
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ func main() {
|
|||||||
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
|
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
|
||||||
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
|
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
|
||||||
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
|
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
|
||||||
|
security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey)
|
||||||
|
|
||||||
// set oidc config
|
// set oidc config
|
||||||
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
||||||
@@ -74,6 +75,8 @@ func main() {
|
|||||||
OIDCTokenVerifier: verifier,
|
OIDCTokenVerifier: verifier,
|
||||||
PAATokenGenerator: security.GeneratePAAToken,
|
PAATokenGenerator: security.GeneratePAAToken,
|
||||||
UserTokenGenerator: security.GenerateUserToken,
|
UserTokenGenerator: security.GenerateUserToken,
|
||||||
|
QueryInfo: security.QueryInfo,
|
||||||
|
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
|
||||||
EnableUserToken: conf.Security.EnableUserToken,
|
EnableUserToken: conf.Security.EnableUserToken,
|
||||||
SessionKey: []byte(conf.Server.SessionKey),
|
SessionKey: []byte(conf.Server.SessionKey),
|
||||||
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ var (
|
|||||||
EncryptionKey []byte
|
EncryptionKey []byte
|
||||||
UserSigningKey []byte
|
UserSigningKey []byte
|
||||||
UserEncryptionKey []byte
|
UserEncryptionKey []byte
|
||||||
|
QuerySigningKey []byte
|
||||||
OIDCProvider *oidc.Provider
|
OIDCProvider *oidc.Provider
|
||||||
Oauth2Config oauth2.Config
|
Oauth2Config oauth2.Config
|
||||||
)
|
)
|
||||||
@@ -221,6 +222,61 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
|
|||||||
return standard, nil
|
return standard, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) {
|
||||||
|
standard := jwt.Claims{}
|
||||||
|
token, err := jwt.ParseSigned(tokenString)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot get token %s", err)
|
||||||
|
return "", errors.New("cannot get token")
|
||||||
|
}
|
||||||
|
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
|
||||||
|
log.Printf("signature validation failure: %s", err)
|
||||||
|
return "", errors.New("signature validation failure")
|
||||||
|
}
|
||||||
|
err = token.Claims(QuerySigningKey, &standard)
|
||||||
|
if err = token.Claims(QuerySigningKey, &standard); err != nil {
|
||||||
|
log.Printf("cannot verify signature %s", err)
|
||||||
|
return "", errors.New("cannot verify signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
// go-jose doesnt verify the expiry
|
||||||
|
err = standard.Validate(jwt.Expected{
|
||||||
|
Issuer: issuer,
|
||||||
|
Time: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("token validation failed due to %s", err)
|
||||||
|
return "", fmt.Errorf("token validation failed due to %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return standard.Subject, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateQueryToken this is a helper function for testing
|
||||||
|
func GenerateQueryToken(ctx context.Context, query string, issuer string) (string, error) {
|
||||||
|
if len(QuerySigningKey) < 32 {
|
||||||
|
return "", errors.New("query token encryption key not long enough or not specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := jwt.Claims{
|
||||||
|
Subject: query,
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
|
||||||
|
Issuer: issuer,
|
||||||
|
}
|
||||||
|
|
||||||
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: QuerySigningKey},
|
||||||
|
(&jose.SignerOptions{}).WithBase64(true))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot encrypt user token due to %s", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
|
return token, err
|
||||||
|
}
|
||||||
|
|
||||||
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
||||||
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
Reference in New Issue
Block a user