Add server implementation of basic auth

This commit is contained in:
Bolke de Bruin
2022-08-24 13:47:26 +02:00
parent 390f6acbcd
commit fb58cb299e
8 changed files with 157 additions and 53 deletions

62
cmd/rdpgw/api/basic.go Normal file
View File

@@ -0,0 +1,62 @@
package api
import (
"context"
"github.com/bolkedebruin/rdpgw/shared/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"log"
"net"
"net/http"
"time"
)
const (
protocol = "unix"
)
func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if ok {
ctx := r.Context()
conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return net.Dial(protocol, addr)
}))
if err != nil {
log.Printf("Cannot reach authentication provider: %s", err)
http.Error(w, "Server error", http.StatusInternalServerError)
return
}
defer conn.Close()
c := auth.NewAuthenticateClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
req := &auth.UserPass{Username: username, Password: password}
res, err := c.Authenticate(ctx, req)
if err != nil {
log.Printf("Error talking to authentication provider: %s", err)
http.Error(w, "Server error", http.StatusInternalServerError)
return
}
if !res.Authenticated {
log.Printf("User %s is not authenticated for this service", username)
} else {
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
// If the Authentication header is not present, is invalid, or the
// username or password is wrong, then set a WWW-Authenticate
// header to inform the client that we expect them to use basic
// authentication and send a 401 Unauthorized response.
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
}

View File

@@ -51,6 +51,8 @@ type Config struct {
ConnectionType int
SplitUserDomain bool
DefaultDomain string
SocketAddress string
Authentication string
}
func (c *Config) NewApi() {

View File

@@ -30,8 +30,10 @@ type ServerConfig struct {
SessionEncryptionKey string `koanf:"sessionencryptionkey"`
SessionStore string `koanf:"sessionstore"`
SendBuf int `koanf:"sendbuf"`
ReceiveBuf int `koanf:"recievebuf"`
ReceiveBuf int `koanf:"receivebuf"`
DisableTLS bool `koanf:"disabletls"`
Authentication string `koanf:"authentication"`
AuthSocket string `koanf:"authsocket"`
}
type OpenIDConfig struct {
@@ -121,6 +123,8 @@ func Load(configFile string) Configuration {
"Server.Port": 443,
"Server.SessionStore": "cookie",
"Server.HostSelection": "roundrobin",
"Server.Authentication": "openid",
"Server.AuthSocket": "/tmp/rdpgw-auth.sock",
"Client.NetworkAutoDetect": 1,
"Client.BandwidthAutoDetect": 1,
"Security.VerifyClientIp": true,
@@ -182,6 +186,9 @@ func Load(configFile string) Configuration {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
}
if Conf.Server.Authentication == "local" && Conf.Server.DisableTLS {
log.Fatalf("basicauth=local and disabletls are mutually exclusive")
}
return Conf
}

View File

@@ -89,6 +89,8 @@ func main() {
ConnectionType: conf.Client.ConnectionType,
SplitUserDomain: conf.Client.SplitUserDomain,
DefaultDomain: conf.Client.DefaultDomain,
SocketAddress: conf.Server.AuthSocket,
Authentication: conf.Server.Authentication,
}
api.NewApi()
@@ -148,11 +150,16 @@ func main() {
ServerConf: &handlerConfig,
}
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
if conf.Server.Authentication == "local" {
http.Handle("/connect", common.EnrichContext(api.BasicAuth(api.HandleDownload)))
http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol)))
} else {
// openid
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.HandleFunc("/callback", api.HandleCallback)
}
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", api.TokenInfo)
http.HandleFunc("/callback", api.HandleCallback)
if conf.Server.DisableTLS {
err = server.ListenAndServe()