fix: improve ios compatibility

This commit is contained in:
Bolke de Bruin
2025-09-05 15:02:57 +02:00
parent c99b4ee58b
commit 85fec5fb2a
2 changed files with 24 additions and 9 deletions

View File

@@ -239,8 +239,9 @@ func main() {
ntlm := web.NTLMAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout} ntlm := web.NTLMAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout}
rdp.NewRoute().HeadersRegexp("Authorization", "NTLM").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol)) rdp.NewRoute().HeadersRegexp("Authorization", "NTLM").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol))
rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol)) rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol))
auth.Register(`NTLM`) auth.Register([]string{`NTLM`, `Negotiate`}, func(r *http.Request) bool {
auth.Register(`Negotiate`) return r.Header.Get("Sec-WebSocket-Protocol") != "binary" // rdp client for ios is incompatible with this NTLM method.
})
} }
// basic auth // basic auth
@@ -248,7 +249,7 @@ func main() {
log.Printf("enabling basic authentication") log.Printf("enabling basic authentication")
q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout} q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout}
rdp.NewRoute().HeadersRegexp("Authorization", "Basic").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol)) rdp.NewRoute().HeadersRegexp("Authorization", "Basic").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol))
auth.Register(`Basic realm="restricted", charset="UTF-8"`) auth.Register([]string{`Basic realm="restricted", charset="UTF-8"`}, nil)
} }
// spnego / kerberos // spnego / kerberos
@@ -266,7 +267,7 @@ func main() {
// kdcproxy // kdcproxy
k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf) k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf)
r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST") r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST")
auth.Register("Negotiate") auth.Register([]string{"Negotiate"}, nil)
} }
// setup server // setup server

View File

@@ -5,21 +5,35 @@ import (
"net/http" "net/http"
) )
type AuthHeader struct {
header string
condition func(*http.Request) bool
}
type AuthMux struct { type AuthMux struct {
headers []string headers []AuthHeader
} }
func NewAuthMux() *AuthMux { func NewAuthMux() *AuthMux {
return &AuthMux{} return &AuthMux{}
} }
func (a *AuthMux) Register(s string) { // Register adds authentication methods with optional condition function
a.headers = append(a.headers, s) func (a *AuthMux) Register(headers []string, condition func(*http.Request) bool) {
for _, header := range headers {
a.headers = append(a.headers, AuthHeader{
header: header,
condition: condition,
})
}
} }
func (a *AuthMux) SetAuthenticate(w http.ResponseWriter, r *http.Request) { func (a *AuthMux) SetAuthenticate(w http.ResponseWriter, r *http.Request) {
for _, s := range a.headers { for _, authHeader := range a.headers {
w.Header().Add("WWW-Authenticate", s) // If condition is nil or condition returns true, add the header
if authHeader.condition == nil || authHeader.condition(r) {
w.Header().Add("WWW-Authenticate", authHeader.header)
}
} }
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
} }