Refactor identity and http routing

This commit is contained in:
Bolke de Bruin
2022-10-18 09:36:41 +02:00
parent b42c3cd3cc
commit db98550455
18 changed files with 402 additions and 199 deletions

View File

@@ -48,7 +48,7 @@ type ServerConfig struct {
SendBuf int `koanf:"sendbuf"` SendBuf int `koanf:"sendbuf"`
ReceiveBuf int `koanf:"receivebuf"` ReceiveBuf int `koanf:"receivebuf"`
Tls string `koanf:"tls"` Tls string `koanf:"tls"`
Authentication string `koanf:"authentication"` Authentication []string `koanf:"authentication"`
AuthSocket string `koanf:"authsocket"` AuthSocket string `koanf:"authsocket"`
} }
@@ -206,15 +206,15 @@ func Load(configFile string) Configuration {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
} }
if Conf.Server.Authentication == "local" && Conf.Server.Tls == "disable" { if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" {
log.Fatalf("basicauth=local and tls=disable are mutually exclusive") log.Fatalf("basicauth=local and tls=disable are mutually exclusive")
} }
if !Conf.Caps.TokenAuth && Conf.Server.Authentication == "openid" { if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() {
log.Fatalf("openid is configured but tokenauth disabled") log.Fatalf("openid is configured but tokenauth disabled")
} }
if Conf.Server.Authentication == AuthenticationKerberos && Conf.Kerberos.Keytab == "" { if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" {
log.Fatalf("kerberos is configured but no keytab was specified") log.Fatalf("kerberos is configured but no keytab was specified")
} }
@@ -226,3 +226,24 @@ func Load(configFile string) Configuration {
return Conf return Conf
} }
func (s *ServerConfig) OpenIDEnabled() bool {
return s.matchAuth("openid")
}
func (s *ServerConfig) KerberosEnabled() bool {
return s.matchAuth("kerberos")
}
func (s *ServerConfig) BasicAuthEnabled() bool {
return s.matchAuth("local")
}
func (s *ServerConfig) matchAuth(needle string) bool {
for _, q := range s.Authentication {
if q == needle {
return true
}
}
return false
}

View File

@@ -0,0 +1,57 @@
package identity
import (
"context"
"net/http"
"time"
)
const (
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
AttrRemoteAddr = "remoteAddr"
AttrClientIp = "clientIp"
AttrProxies = "proxyAddresses"
AttrAccessToken = "accessToken" // todo remove for security reasons
)
type Identity interface {
UserName() string
SetUserName(string)
DisplayName() string
SetDisplayName(string)
Domain() string
SetDomain(string)
Authenticated() bool
SetAuthenticated(bool)
AuthTime() time.Time
SetAuthTime(time2 time.Time)
SessionId() string
SetAttribute(string, interface{})
GetAttribute(string) interface{}
Attributes() map[string]interface{}
DelAttribute(string)
Email() string
SetEmail(string)
Expiry() time.Time
SetExpiry(time.Time)
Marshal() ([]byte, error)
Unmarshal([]byte) error
}
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, CTXKey, id)
return r.WithContext(ctx)
}
func FromRequestCtx(r *http.Request) Identity {
return FromCtx(r.Context())
}
func FromCtx(ctx context.Context) Identity {
if id, ok := ctx.Value(CTXKey).(Identity); ok {
return id
}
return nil
}

View File

@@ -0,0 +1,28 @@
package identity
import (
"log"
"testing"
)
func TestMarshalling(t *testing.T) {
u := NewUser()
u.SetUserName("ANAME")
u.SetAuthenticated(true)
u.SetDomain("DOMAIN")
c := NewUser()
data, err := u.Marshal()
if err != nil {
log.Fatalf("Cannot marshal %s", err)
}
err = c.Unmarshal(data)
if err != nil {
t.Fatalf("Error while unmarshalling: %s", err)
}
if u.UserName() != c.UserName() || u.Authenticated() != c.Authenticated() || u.Domain() != c.Domain() {
t.Fatalf("identities not equal: %+v != %+v", u, c)
}
}

View File

@@ -1,60 +1,12 @@
package common package identity
import ( import (
"context" "bytes"
"encoding/gob"
"github.com/google/uuid" "github.com/google/uuid"
"net/http"
"time" "time"
) )
const (
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
AttrRemoteAddr = "remoteAddr"
AttrClientIp = "clientIp"
AttrProxies = "proxyAddresses"
AttrAccessToken = "accessToken" // todo remove for security reasons
)
type Identity interface {
UserName() string
SetUserName(string)
DisplayName() string
SetDisplayName(string)
Domain() string
SetDomain(string)
Authenticated() bool
SetAuthenticated(bool)
AuthTime() time.Time
SetAuthTime(time2 time.Time)
SessionId() string
SetAttribute(string, interface{})
GetAttribute(string) interface{}
Attributes() map[string]interface{}
DelAttribute(string)
Email() string
SetEmail(string)
Expiry() time.Time
SetExpiry(time.Time)
}
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, CTXKey, id)
return r.WithContext(ctx)
}
func FromRequestCtx(r *http.Request) Identity {
return FromCtx(r.Context())
}
func FromCtx(ctx context.Context) Identity {
if id, ok := ctx.Value(CTXKey).(Identity); ok {
return id
}
return nil
}
type User struct { type User struct {
authenticated bool authenticated bool
domain string domain string
@@ -68,6 +20,19 @@ type User struct {
groupMembership map[string]bool groupMembership map[string]bool
} }
type user struct {
Authenticated bool
UserName string
Domain string
DisplayName string
Email string
AuthTime time.Time
SessionId string
Expiry time.Time
Attributes map[string]interface{}
GroupMembership map[string]bool
}
func NewUser() *User { func NewUser() *User {
uuid := uuid.New().String() uuid := uuid.New().String()
return &User{ return &User{
@@ -158,3 +123,48 @@ func (u *User) Expiry() time.Time {
func (u *User) SetExpiry(t time.Time) { func (u *User) SetExpiry(t time.Time) {
u.expiry = t u.expiry = t
} }
func (u *User) Marshal() ([]byte, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
uu := user{
Authenticated: u.authenticated,
UserName: u.userName,
Domain: u.domain,
DisplayName: u.displayName,
Email: u.email,
AuthTime: u.authTime,
SessionId: u.sessionId,
Expiry: u.expiry,
Attributes: u.attributes,
GroupMembership: u.groupMembership,
}
err := enc.Encode(uu)
if err != nil {
return []byte{}, err
}
return buf.Bytes(), nil
}
func (u *User) Unmarshal(b []byte) error {
buf := bytes.NewBuffer(b)
dec := gob.NewDecoder(buf)
var uu user
err := dec.Decode(&uu)
if err != nil {
return err
}
u.sessionId = uu.SessionId
u.userName = uu.UserName
u.domain = uu.Domain
u.displayName = uu.DisplayName
u.email = uu.Email
u.authenticated = uu.Authenticated
u.authTime = uu.AuthTime
u.expiry = uu.Expiry
u.attributes = uu.Attributes
u.groupMembership = uu.GroupMembership
return nil
}

View File

@@ -7,14 +7,13 @@ import (
"github.com/bolkedebruin/gokrb5/v8/keytab" "github.com/bolkedebruin/gokrb5/v8/keytab"
"github.com/bolkedebruin/gokrb5/v8/service" "github.com/bolkedebruin/gokrb5/v8/service"
"github.com/bolkedebruin/gokrb5/v8/spnego" "github.com/bolkedebruin/gokrb5/v8/spnego"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/kdcproxy" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/kdcproxy"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/web" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/web"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions" "github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/thought-machine/go-flags" "github.com/thought-machine/go-flags"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
@@ -26,13 +25,18 @@ import (
"strconv" "strconv"
) )
const (
gatewayEndPoint = "/remoteDesktopGateway/"
kdcProxyEndPoint = "/KdcProxy"
)
var opts struct { var opts struct {
ConfigFile string `short:"c" long:"conf" default:"rdpgw.yaml" description:"config file (yaml)"` ConfigFile string `short:"c" long:"conf" default:"rdpgw.yaml" description:"config file (yaml)"`
} }
var conf config.Configuration var conf config.Configuration
func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC { func initOIDC(callbackUrl *url.URL) *web.OIDC {
// set oidc config // set oidc config
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
if err != nil { if err != nil {
@@ -56,7 +60,6 @@ func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
o := web.OIDCConfig{ o := web.OIDCConfig{
OAuth2Config: &oauthConfig, OAuth2Config: &oauthConfig,
OIDCTokenVerifier: verifier, OIDCTokenVerifier: verifier,
SessionStore: store,
} }
return o.New() return o.New()
@@ -91,19 +94,13 @@ func main() {
security.Hosts = conf.Server.Hosts security.Hosts = conf.Server.Hosts
// init session store // init session store
sessionConf := web.SessionManagerConf{ web.InitStore([]byte(conf.Server.SessionKey), []byte(conf.Server.SessionEncryptionKey), conf.Server.SessionStore)
SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
StoreType: conf.Server.SessionStore,
}
store := sessionConf.Init()
// configure web backend // configure web backend
w := &web.Config{ w := &web.Config{
QueryInfo: security.QueryInfo, QueryInfo: security.QueryInfo,
QueryTokenIssuer: conf.Security.QueryTokenIssuer, QueryTokenIssuer: conf.Security.QueryTokenIssuer,
EnableUserToken: conf.Security.EnableUserToken, EnableUserToken: conf.Security.EnableUserToken,
SessionStore: store,
Hosts: conf.Server.Hosts, Hosts: conf.Server.Hosts,
HostSelection: conf.Server.HostSelection, HostSelection: conf.Server.HostSelection,
RdpOpts: web.RdpOpts{ RdpOpts: web.RdpOpts{
@@ -128,6 +125,7 @@ func main() {
log.Printf("Starting remote desktop gateway server") log.Printf("Starting remote desktop gateway server")
cfg := &tls.Config{} cfg := &tls.Config{}
// configure tls security
if conf.Server.Tls == config.TlsDisable { if conf.Server.Tls == config.TlsDisable {
log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator") log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator")
} else { } else {
@@ -174,13 +172,7 @@ func main() {
} }
} }
server := http.Server{ // gateway confg
Addr: ":" + strconv.Itoa(conf.Server.Port),
TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
}
// create the gateway
gw := protocol.Gateway{ gw := protocol.Gateway{
RedirectFlags: protocol.RedirectFlags{ RedirectFlags: protocol.RedirectFlags{
Clipboard: conf.Caps.EnableClipboard, Clipboard: conf.Caps.EnableClipboard,
@@ -205,31 +197,72 @@ func main() {
gw.CheckHost = security.CheckHost gw.CheckHost = security.CheckHost
} }
if conf.Server.Authentication == config.AuthenticationBasic { r := mux.NewRouter()
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol))) // ensure identity is set in context and get some extra info
} else if conf.Server.Authentication == config.AuthenticationKerberos { r.Use(web.EnrichContext)
// prometheus metrics
r.Handle("/metrics", promhttp.Handler())
// for sso callbacks
r.HandleFunc("/tokeninfo", web.TokenInfo)
// gateway endpoint
rdp := r.PathPrefix(gatewayEndPoint).Subrouter()
// openid
if conf.Server.OpenIDEnabled() {
log.Printf("enabling openid extended authentication")
o := initOIDC(url)
r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload)))
r.HandleFunc("/callback", o.HandleCallback)
// only enable un-auth endpoint for openid only config
if !conf.Server.KerberosEnabled() || !conf.Server.BasicAuthEnabled() {
rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol)
}
}
// for stacking of authentication
auth := web.NewAuthMux()
// basic auth
if conf.Server.BasicAuthEnabled() {
log.Printf("enabling basic authentication")
q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
rdp.Headers("Authorization", "Basic*").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol))
auth.Register(`Basic realm="restricted", charset="UTF-8"`)
}
// spnego / kerberos
if conf.Server.KerberosEnabled() {
log.Printf("enabling kerberos authentication")
keytab, err := keytab.Load(conf.Kerberos.Keytab) keytab, err := keytab.Load(conf.Kerberos.Keytab)
if err != nil { if err != nil {
log.Fatalf("Cannot load keytab: %s", err) log.Fatalf("Cannot load keytab: %s", err)
} }
http.Handle("/remoteDesktopGateway/", common.EnrichContext( rdp.Headers("Authorization", "Negotiate*").Handler(
spnego.SPNEGOKRB5Authenticate( spnego.SPNEGOKRB5Authenticate(web.TransposeSPNEGOContext(http.HandlerFunc(gw.HandleGatewayProtocol)),
common.FixKerberosContext(http.HandlerFunc(gw.HandleGatewayProtocol)),
keytab, keytab,
service.Logger(log.Default()))), service.Logger(log.Default())))
)
// kdcproxy
k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf) k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf)
http.HandleFunc("/KdcProxy", k.Handler) r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST")
} else { auth.Register("Negotiate")
// openid }
oidc := initOIDC(url, store)
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload)))) // allow stacking of authentication
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) rdp.Use(auth.Route)
http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback)))
// setup server
server := http.Server{
Addr: ":" + strconv.Itoa(conf.Server.Port),
Handler: r,
TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
} }
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", web.TokenInfo)
if conf.Server.Tls == config.TlsDisable { if conf.Server.Tls == config.TlsDisable {
err = server.ListenAndServe() err = server.ListenAndServe()

View File

@@ -3,7 +3,7 @@ package protocol
import ( import (
"context" "context"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@@ -61,14 +61,14 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
var t *Tunnel var t *Tunnel
ctx := r.Context() ctx := r.Context()
id := common.FromRequestCtx(r) id := identity.FromRequestCtx(r)
connId := r.Header.Get(rdgConnectionIdKey) connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId) x, found := c.Get(connId)
if !found { if !found {
t = &Tunnel{ t = &Tunnel{
RDGId: connId, RDGId: connId,
RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string), RemoteAddr: id.GetAttribute(identity.AttrRemoteAddr).(string),
User: id, User: id,
} }
} else { } else {
@@ -183,14 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) { func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil) log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
id := common.FromRequestCtx(r) id := identity.FromRequestCtx(r)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
out, err := transport.NewLegacy(w) out, err := transport.NewLegacy(w)
if err != nil { if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return return
} }
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp)) log.Printf("Opening RDGOUT for client %s", id.GetAttribute(identity.AttrClientIp))
t.transportOut = out t.transportOut = out
out.SendAccept(true) out.SendAccept(true)
@@ -212,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
t.transportIn = in t.transportIn = in
c.Set(t.RDGId, t, cache.DefaultExpiration) c.Set(t.RDGId, t, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp)) log.Printf("Opening RDGIN for client %s", id.GetAttribute(identity.AttrClientIp))
in.SendAccept(false) in.SendAccept(false)
// read some initial data // read some initial data
in.Drain() in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp)) log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(identity.AttrClientIp))
handler := NewProcessor(g, t) handler := NewProcessor(g, t)
RegisterTunnel(t, handler) RegisterTunnel(t, handler)
defer RemoveTunnel(t) defer RemoveTunnel(t)

View File

@@ -6,7 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"io" "io"
"log" "log"
"net" "net"
@@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error {
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp)) log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
if p.state != SERVER_STATE_INITIALIZED { if p.state != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
@@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error {
_, cookie := p.tunnelRequest(pkt) _, cookie := p.tunnelRequest(pkt)
if p.gw.CheckPAACookie != nil { if p.gw.CheckPAACookie != nil {
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp)) log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)

View File

@@ -1,7 +1,7 @@
package protocol package protocol
import ( import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"net" "net"
"time" "time"
@@ -27,7 +27,7 @@ type Tunnel struct {
// The obtained client ip address // The obtained client ip address
RemoteAddr string RemoteAddr string
// User // User
User common.Identity User identity.Identity
// rwc is the underlying connection to the remote desktop server. // rwc is the underlying connection to the remote desktop server.
// It is of the type *net.TCPConn // It is of the type *net.TCPConn

View File

@@ -2,7 +2,7 @@ package security
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"testing" "testing"
) )
@@ -18,7 +18,7 @@ var (
) )
func TestCheckHost(t *testing.T) { func TestCheckHost(t *testing.T) {
info.User = common.NewUser() info.User = identity.NewUser()
info.User.SetUserName("MYNAME") info.User.SetUserName("MYNAME")
ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info) ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
@@ -46,10 +46,10 @@ func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
} }
// use identity from context rather then set by tunnel // use identity from context rather then set by tunnel
id := common.FromCtx(ctx) id := identity.FromCtx(ctx)
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) { if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(identity.AttrClientIp) {
log.Printf("Current client ip address %s does not match token client ip %s", log.Printf("Current client ip address %s does not match token client ip %s",
id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr) id.GetAttribute(identity.AttrClientIp), tunnel.RemoteAddr)
return false, nil return false, nil
} }
return next(ctx, host) return next(ctx, host)
@@ -129,11 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
Subject: username, Subject: username,
} }
id := common.FromCtx(ctx) id := identity.FromCtx(ctx)
private := customClaims{ private := customClaims{
RemoteServer: server, RemoteServer: server,
ClientIP: id.GetAttribute(common.AttrClientIp).(string), ClientIP: id.GetAttribute(identity.AttrClientIp).(string),
AccessToken: id.GetAttribute(common.AttrAccessToken).(string), AccessToken: id.GetAttribute(identity.AttrAccessToken).(string),
} }
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil { if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {

View File

@@ -2,7 +2,7 @@ package web
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/shared/auth" "github.com/bolkedebruin/rdpgw/shared/auth"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@@ -53,11 +53,11 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
log.Printf("User %s is not authenticated for this service", username) log.Printf("User %s is not authenticated for this service", username)
} else { } else {
log.Printf("User %s authenticated", username) log.Printf("User %s authenticated", username)
id := common.FromRequestCtx(r) id := identity.FromRequestCtx(r)
id.SetUserName(username) id.SetUserName(username)
id.SetAuthenticated(true) id.SetAuthenticated(true)
id.SetAuthTime(time.Now()) id.SetAuthTime(time.Now())
next.ServeHTTP(w, common.AddToRequestCtx(id, r)) next.ServeHTTP(w, identity.AddToRequestCtx(id, r))
return return
} }
@@ -66,7 +66,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
// username or password is wrong, then set a WWW-Authenticate // username or password is wrong, then set a WWW-Authenticate
// header to inform the client that we expect them to use basic // header to inform the client that we expect them to use basic
// authentication and send a 401 Unauthorized response. // authentication and send a 401 Unauthorized response.
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) w.Header().Add("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
} }
} }

View File

@@ -1,7 +1,7 @@
package common package web
import ( import (
"context" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/jcmturner/goidentity/v6" "github.com/jcmturner/goidentity/v6"
"log" "log"
"net" "net"
@@ -9,16 +9,22 @@ import (
"strings" "strings"
) )
const (
CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token"
)
func EnrichContext(next http.Handler) http.Handler { func EnrichContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := FromRequestCtx(r) id, err := GetSessionIdentity(r)
if id == nil { if err != nil {
id = NewUser() http.Error(w, err.Error(), http.StatusInternalServerError)
return
} }
if id == nil {
id = identity.NewUser()
if err := SaveSessionIdentity(r, w, id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t", log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t",
id.SessionId(), id.UserName(), id.Authenticated()) id.SessionId(), id.UserName(), id.Authenticated())
@@ -33,39 +39,30 @@ func EnrichContext(next http.Handler) http.Handler {
if len(ips) > 1 { if len(ips) > 1 {
proxies = ips[1:] proxies = ips[1:]
} }
id.SetAttribute(AttrClientIp, clientIp) id.SetAttribute(identity.AttrClientIp, clientIp)
id.SetAttribute(AttrProxies, proxies) id.SetAttribute(identity.AttrProxies, proxies)
} }
id.SetAttribute(AttrRemoteAddr, r.RemoteAddr) id.SetAttribute(identity.AttrRemoteAddr, r.RemoteAddr)
if h == "" { if h == "" {
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr) clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
id.SetAttribute(AttrClientIp, clientIp) id.SetAttribute(identity.AttrClientIp, clientIp)
} }
next.ServeHTTP(w, AddToRequestCtx(id, r)) next.ServeHTTP(w, identity.AddToRequestCtx(id, r))
}) })
} }
func FixKerberosContext(next http.Handler) http.Handler { func TransposeSPNEGOContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gid := goidentity.FromHTTPRequestContext(r) gid := goidentity.FromHTTPRequestContext(r)
if gid != nil { if gid != nil {
id := FromRequestCtx(r) id := identity.FromRequestCtx(r)
id.SetUserName(gid.UserName()) id.SetUserName(gid.UserName())
id.SetAuthenticated(gid.Authenticated()) id.SetAuthenticated(gid.Authenticated())
id.SetDomain(gid.Domain()) id.SetDomain(gid.Domain())
id.SetAuthTime(gid.AuthTime()) id.SetAuthTime(gid.AuthTime())
r = AddToRequestCtx(id, r) r = identity.AddToRequestCtx(id, r)
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func GetAccessToken(ctx context.Context) string {
token, ok := ctx.Value(CtxAccessToken).(string)
if !ok {
log.Printf("cannot get access token from context")
return ""
}
return token
}

View File

@@ -3,9 +3,8 @@ package web
import ( import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"math/rand" "math/rand"
@@ -14,23 +13,20 @@ import (
) )
const ( const (
CacheExpiration = time.Minute * 2 CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5 CleanupInterval = time.Minute * 5
sessionKeyAuthenticated = "authenticated" oidcKeyUserName = "preferred_username"
oidcKeyUserName = "preferred_username"
) )
type OIDC struct { type OIDC struct {
oAuth2Config *oauth2.Config oAuth2Config *oauth2.Config
oidcTokenVerifier *oidc.IDTokenVerifier oidcTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache stateStore *cache.Cache
sessionStore sessions.Store
} }
type OIDCConfig struct { type OIDCConfig struct {
OAuth2Config *oauth2.Config OAuth2Config *oauth2.Config
OIDCTokenVerifier *oidc.IDTokenVerifier OIDCTokenVerifier *oidc.IDTokenVerifier
SessionStore sessions.Store
} }
func (c *OIDCConfig) New() *OIDC { func (c *OIDCConfig) New() *OIDC {
@@ -38,7 +34,6 @@ func (c *OIDCConfig) New() *OIDC {
oAuth2Config: c.OAuth2Config, oAuth2Config: c.OAuth2Config,
oidcTokenVerifier: c.OIDCTokenVerifier, oidcTokenVerifier: c.OIDCTokenVerifier,
stateStore: cache.New(CacheExpiration, CleanupInterval), stateStore: cache.New(CacheExpiration, CleanupInterval),
sessionStore: c.SessionStore,
} }
} }
@@ -85,22 +80,13 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
return return
} }
session, err := h.sessionStore.Get(r, RdpGwSession) id := identity.FromRequestCtx(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
id := common.FromRequestCtx(r)
id.SetUserName(data[oidcKeyUserName].(string)) id.SetUserName(data[oidcKeyUserName].(string))
id.SetAuthenticated(true) id.SetAuthenticated(true)
id.SetAuthTime(time.Now()) id.SetAuthTime(time.Now())
id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken) id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken)
session.Options.MaxAge = MaxAge if err = SaveSessionIdentity(r, w, id); err != nil {
session.Values[common.CTXKey] = id
if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
@@ -109,14 +95,9 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
func (h *OIDC) Authenticated(next http.Handler) http.Handler { func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := h.sessionStore.Get(r, RdpGwSession) id := identity.FromRequestCtx(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
id := session.Values[common.CTXKey].(common.Identity) if !id.Authenticated() {
if id == nil {
seed := make([]byte, 16) seed := make([]byte, 16)
rand.Read(seed) rand.Read(seed)
state := hex.EncodeToString(seed) state := hex.EncodeToString(seed)
@@ -126,6 +107,6 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
} }
// replace the identity with the one from the sessions // replace the identity with the one from the sessions
next.ServeHTTP(w, common.AddToRequestCtx(id, r)) next.ServeHTTP(w, r)
}) })
} }

31
cmd/rdpgw/web/router.go Normal file
View File

@@ -0,0 +1,31 @@
package web
import (
"net/http"
)
type AuthMux struct {
headers []string
}
func NewAuthMux() *AuthMux {
return &AuthMux{}
}
func (a *AuthMux) Register(s string) {
a.headers = append(a.headers, s)
}
func (a *AuthMux) Route(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := r.Header.Get("Authorization")
if h == "" {
for _, s := range a.headers {
w.Header().Add("WWW-Authenticate", s)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
}
next.ServeHTTP(w, r)
})
}

View File

@@ -1,30 +1,75 @@
package web package web
import ( import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"log" "log"
"net/http"
"os" "os"
) )
type SessionManagerConf struct { const (
SessionKey []byte rdpGwSession = "RDPGWSESSION"
SessionEncryptionKey []byte MaxAge = 120
StoreType string identityKey = "RDPGWID"
} )
func (c *SessionManagerConf) Init() sessions.Store { var sessionStore sessions.Store
if len(c.SessionKey) < 32 {
func InitStore(sessionKey []byte, encryptionKey []byte, storeType string) {
if len(sessionKey) < 32 {
log.Fatal("Session key too small") log.Fatal("Session key too small")
} }
if len(c.SessionEncryptionKey) < 32 { if len(encryptionKey) < 32 {
log.Fatal("Session key too small") log.Fatal("Session key too small")
} }
if c.StoreType == "file" { if storeType == "file" {
log.Println("Filesystem is used as session storage") log.Println("Filesystem is used as session storage")
return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey) sessionStore = sessions.NewFilesystemStore(os.TempDir(), sessionKey, encryptionKey)
} else { } else {
log.Println("Cookies are used as session storage") log.Println("Cookies are used as session storage")
return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey) sessionStore = sessions.NewCookieStore(sessionKey, encryptionKey)
} }
} }
func GetSession(r *http.Request) (*sessions.Session, error) {
session, err := sessionStore.Get(r, rdpGwSession)
if err != nil {
return nil, err
}
return session, nil
}
func GetSessionIdentity(r *http.Request) (identity.Identity, error) {
s, err := GetSession(r)
if err != nil {
return nil, err
}
idData := s.Values[identityKey]
if idData == nil {
return nil, nil
}
id := identity.NewUser()
id.Unmarshal(idData.([]byte))
return id, nil
}
func SaveSessionIdentity(r *http.Request, w http.ResponseWriter, id identity.Identity) error {
session, err := GetSession(r)
if err != nil {
return err
}
session.Options.MaxAge = MaxAge
idData, err := id.Marshal()
if err != nil {
return err
}
session.Values[identityKey] = idData
return sessionStore.Save(r, w, session)
}

View File

@@ -5,7 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"github.com/gorilla/sessions" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
@@ -14,17 +14,11 @@ import (
"time" "time"
) )
const (
RdpGwSession = "RDPGWSESSION"
MaxAge = 120
)
type TokenGeneratorFunc func(context.Context, string, string) (string, error) type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error) type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type QueryInfoFunc func(context.Context, string, string) (string, error) type QueryInfoFunc func(context.Context, string, string) (string, error)
type Config struct { type Config struct {
SessionStore sessions.Store
PAATokenGenerator TokenGeneratorFunc PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc UserTokenGenerator UserTokenGeneratorFunc
QueryInfo QueryInfoFunc QueryInfo QueryInfoFunc
@@ -46,7 +40,6 @@ type RdpOpts struct {
} }
type Handler struct { type Handler struct {
sessionStore sessions.Store
paaTokenGenerator TokenGeneratorFunc paaTokenGenerator TokenGeneratorFunc
enableUserToken bool enableUserToken bool
userTokenGenerator UserTokenGeneratorFunc userTokenGenerator UserTokenGeneratorFunc
@@ -63,7 +56,6 @@ func (c *Config) NewHandler() *Handler {
log.Fatal("Not enough hosts to connect to specified") log.Fatal("Not enough hosts to connect to specified")
} }
return &Handler{ return &Handler{
sessionStore: c.SessionStore,
paaTokenGenerator: c.PAATokenGenerator, paaTokenGenerator: c.PAATokenGenerator,
enableUserToken: c.EnableUserToken, enableUserToken: c.EnableUserToken,
userTokenGenerator: c.UserTokenGenerator, userTokenGenerator: c.UserTokenGenerator,
@@ -132,13 +124,13 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
} }
func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) { func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
id := identity.FromRequestCtx(r)
ctx := r.Context() ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string)
opts := h.rdpOpts opts := h.rdpOpts
if !ok { if !id.Authenticated() {
log.Printf("preferred_username not found in context") log.Printf("unauthenticated user %s", id.UserName())
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError) http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
return return
} }
@@ -149,13 +141,13 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
host = strings.Replace(host, "{{ preferred_username }}", userName, 1) host = strings.Replace(host, "{{ preferred_username }}", id.UserName(), 1)
// split the username into user and domain // split the username into user and domain
var user = userName var user = id.UserName()
var domain = opts.DefaultDomain var domain = opts.DefaultDomain
if opts.SplitUserDomain { if opts.SplitUserDomain {
creds := strings.SplitN(userName, "@", 2) creds := strings.SplitN(id.UserName(), "@", 2)
user = creds[0] user = creds[0]
if len(creds) > 1 { if len(creds) > 1 {
domain = creds[1] domain = creds[1]
@@ -203,6 +195,8 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
rdp.Connection.GatewayHostname = h.gatewayAddress.Host rdp.Connection.GatewayHostname = h.gatewayAddress.Host
rdp.Connection.GatewayCredentialsSource = SourceCookie rdp.Connection.GatewayCredentialsSource = SourceCookie
rdp.Connection.GatewayAccessToken = token rdp.Connection.GatewayAccessToken = token
rdp.Connection.GatewayCredentialMethod = 1
rdp.Connection.GatewayUsageMethod = 1
rdp.Session.NetworkAutodetect = opts.NetworkAutoDetect != 0 rdp.Session.NetworkAutodetect = opts.NetworkAutoDetect != 0
rdp.Session.BandwidthAutodetect = opts.BandwidthAutoDetect != 0 rdp.Session.BandwidthAutodetect = opts.BandwidthAutoDetect != 0
rdp.Session.ConnectionType = opts.ConnectionType rdp.Session.ConnectionType = opts.ConnectionType

View File

@@ -2,6 +2,7 @@ package web
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -113,9 +114,13 @@ func TestHandler_HandleDownload(t *testing.T) {
} }
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
id := identity.NewUser()
id.SetUserName(testuser)
id.SetAuthenticated(true)
req = identity.AddToRequestCtx(id, req)
ctx := req.Context() ctx := req.Context()
ctx = context.WithValue(ctx, "preferred_username", testuser)
req = req.WithContext(ctx)
u, _ := url.Parse(gateway) u, _ := url.Parse(gateway)
c := Config{ c := Config{

1
go.mod
View File

@@ -8,6 +8,7 @@ require (
github.com/fatih/structs v1.1.0 github.com/fatih/structs v1.1.0
github.com/go-jose/go-jose/v3 v3.0.0 github.com/go-jose/go-jose/v3 v3.0.0
github.com/google/uuid v1.1.2 github.com/google/uuid v1.1.2
github.com/gorilla/mux v1.8.0
github.com/gorilla/sessions v1.2.1 github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/jcmturner/gofork v1.7.6 github.com/jcmturner/gofork v1.7.6