diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 07e4e44..73e0d44 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -239,8 +239,9 @@ func main() { ntlm := web.NTLMAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout} rdp.NewRoute().HeadersRegexp("Authorization", "NTLM").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol)) rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol)) - auth.Register(`NTLM`) - auth.Register(`Negotiate`) + auth.Register([]string{`NTLM`, `Negotiate`}, func(r *http.Request) bool { + return r.Header.Get("Sec-WebSocket-Protocol") != "binary" // rdp client for ios is incompatible with this NTLM method. + }) } // basic auth @@ -248,7 +249,7 @@ func main() { log.Printf("enabling basic authentication") q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout} rdp.NewRoute().HeadersRegexp("Authorization", "Basic").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol)) - auth.Register(`Basic realm="restricted", charset="UTF-8"`) + auth.Register([]string{`Basic realm="restricted", charset="UTF-8"`}, nil) } // spnego / kerberos @@ -266,7 +267,7 @@ func main() { // kdcproxy k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf) r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST") - auth.Register("Negotiate") + auth.Register([]string{"Negotiate"}, nil) } // setup server diff --git a/cmd/rdpgw/web/mux.go b/cmd/rdpgw/web/mux.go index 02069ae..e06a233 100644 --- a/cmd/rdpgw/web/mux.go +++ b/cmd/rdpgw/web/mux.go @@ -5,21 +5,35 @@ import ( "net/http" ) +type AuthHeader struct { + header string + condition func(*http.Request) bool +} + type AuthMux struct { - headers []string + headers []AuthHeader } func NewAuthMux() *AuthMux { return &AuthMux{} } -func (a *AuthMux) Register(s string) { - a.headers = append(a.headers, s) +// Register adds authentication methods with optional condition function +func (a *AuthMux) Register(headers []string, condition func(*http.Request) bool) { + for _, header := range headers { + a.headers = append(a.headers, AuthHeader{ + header: header, + condition: condition, + }) + } } func (a *AuthMux) SetAuthenticate(w http.ResponseWriter, r *http.Request) { - for _, s := range a.headers { - w.Header().Add("WWW-Authenticate", s) + for _, authHeader := range a.headers { + // If condition is nil or condition returns true, add the header + if authHeader.condition == nil || authHeader.condition(r) { + w.Header().Add("WWW-Authenticate", authHeader.header) + } } http.Error(w, "Unauthorized", http.StatusUnauthorized) }