mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-28 14:56:36 +00:00
Add server implementation of basic auth
This commit is contained in:
62
cmd/rdpgw/api/basic.go
Normal file
62
cmd/rdpgw/api/basic.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
protocol = "unix"
|
||||
)
|
||||
|
||||
func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if ok {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return net.Dial(protocol, addr)
|
||||
}))
|
||||
if err != nil {
|
||||
log.Printf("Cannot reach authentication provider: %s", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
c := auth.NewAuthenticateClient(conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
req := &auth.UserPass{Username: username, Password: password}
|
||||
res, err := c.Authenticate(ctx, req)
|
||||
if err != nil {
|
||||
log.Printf("Error talking to authentication provider: %s", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !res.Authenticated {
|
||||
log.Printf("User %s is not authenticated for this service", username)
|
||||
} else {
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
// If the Authentication header is not present, is invalid, or the
|
||||
// username or password is wrong, then set a WWW-Authenticate
|
||||
// header to inform the client that we expect them to use basic
|
||||
// authentication and send a 401 Unauthorized response.
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
@@ -51,6 +51,8 @@ type Config struct {
|
||||
ConnectionType int
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
SocketAddress string
|
||||
Authentication string
|
||||
}
|
||||
|
||||
func (c *Config) NewApi() {
|
||||
|
||||
@@ -30,8 +30,10 @@ type ServerConfig struct {
|
||||
SessionEncryptionKey string `koanf:"sessionencryptionkey"`
|
||||
SessionStore string `koanf:"sessionstore"`
|
||||
SendBuf int `koanf:"sendbuf"`
|
||||
ReceiveBuf int `koanf:"recievebuf"`
|
||||
ReceiveBuf int `koanf:"receivebuf"`
|
||||
DisableTLS bool `koanf:"disabletls"`
|
||||
Authentication string `koanf:"authentication"`
|
||||
AuthSocket string `koanf:"authsocket"`
|
||||
}
|
||||
|
||||
type OpenIDConfig struct {
|
||||
@@ -121,6 +123,8 @@ func Load(configFile string) Configuration {
|
||||
"Server.Port": 443,
|
||||
"Server.SessionStore": "cookie",
|
||||
"Server.HostSelection": "roundrobin",
|
||||
"Server.Authentication": "openid",
|
||||
"Server.AuthSocket": "/tmp/rdpgw-auth.sock",
|
||||
"Client.NetworkAutoDetect": 1,
|
||||
"Client.BandwidthAutoDetect": 1,
|
||||
"Security.VerifyClientIp": true,
|
||||
@@ -182,6 +186,9 @@ func Load(configFile string) Configuration {
|
||||
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
|
||||
}
|
||||
|
||||
if Conf.Server.Authentication == "local" && Conf.Server.DisableTLS {
|
||||
log.Fatalf("basicauth=local and disabletls are mutually exclusive")
|
||||
}
|
||||
return Conf
|
||||
|
||||
}
|
||||
|
||||
@@ -89,6 +89,8 @@ func main() {
|
||||
ConnectionType: conf.Client.ConnectionType,
|
||||
SplitUserDomain: conf.Client.SplitUserDomain,
|
||||
DefaultDomain: conf.Client.DefaultDomain,
|
||||
SocketAddress: conf.Server.AuthSocket,
|
||||
Authentication: conf.Server.Authentication,
|
||||
}
|
||||
api.NewApi()
|
||||
|
||||
@@ -148,11 +150,16 @@ func main() {
|
||||
ServerConf: &handlerConfig,
|
||||
}
|
||||
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||
if conf.Server.Authentication == "local" {
|
||||
http.Handle("/connect", common.EnrichContext(api.BasicAuth(api.HandleDownload)))
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol)))
|
||||
} else {
|
||||
// openid
|
||||
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||
http.HandleFunc("/callback", api.HandleCallback)
|
||||
}
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/tokeninfo", api.TokenInfo)
|
||||
http.HandleFunc("/callback", api.HandleCallback)
|
||||
|
||||
if conf.Server.DisableTLS {
|
||||
err = server.ListenAndServe()
|
||||
|
||||
Reference in New Issue
Block a user