mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-27 22:46:37 +00:00
Support for NTLM authentication added (#109)
* Support for NTLM authentication added To support NTLM authentication, a database is added as an authentication source. Currently, only the configuration file is supported as a database. Database authentication supports Basic and NTLM authentication protcols. ServerConfig.BasicAuthEnabled renamed to LocalEnabled as Basic auth can be used with NTLM or Local.
This commit is contained in:
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/config"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/database"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/ntlm"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"github.com/msteinert/pam/v2"
|
||||
"github.com/thought-machine/go-flags"
|
||||
@@ -21,16 +24,24 @@ const (
|
||||
var opts struct {
|
||||
ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"`
|
||||
SocketAddr string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"`
|
||||
ConfigFile string `short:"c" long:"conf" default:"rdpgw-auth.yaml" description:"users config file for NTLM (yaml)"`
|
||||
}
|
||||
|
||||
type AuthServiceImpl struct {
|
||||
auth.UnimplementedAuthenticateServer
|
||||
|
||||
serviceName string
|
||||
ntlm *ntlm.NTLMAuth
|
||||
}
|
||||
|
||||
var conf config.Configuration
|
||||
var _ auth.AuthenticateServer = (*AuthServiceImpl)(nil)
|
||||
|
||||
func NewAuthService(serviceName string) auth.AuthenticateServer {
|
||||
s := &AuthServiceImpl{serviceName: serviceName}
|
||||
func NewAuthService(serviceName string, database database.Database) auth.AuthenticateServer {
|
||||
s := &AuthServiceImpl{
|
||||
serviceName: serviceName,
|
||||
ntlm: ntlm.NewNTLMAuth(database),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -77,12 +88,35 @@ func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPa
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *AuthServiceImpl) NTLM(ctx context.Context, message *auth.NtlmRequest) (*auth.NtlmResponse, error) {
|
||||
r, err := s.ntlm.Authenticate(message)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[%s] NTLM failed: %s", message.Session, err)
|
||||
} else if r.Authenticated {
|
||||
log.Printf("[%s] User: %s authenticated using NTLM", message.Session, r.Username)
|
||||
} else if r.NtlmMessage != "" {
|
||||
log.Printf("[%s] Sending NTLM challenge", message.Session)
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
||||
|
||||
func main() {
|
||||
_, err := flags.Parse(&opts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
var fErr *flags.Error
|
||||
if errors.As(err, &fErr) {
|
||||
if fErr.Type == flags.ErrHelp {
|
||||
fmt.Printf("Acknowledgements:\n")
|
||||
fmt.Printf(" - This product includes software developed by the Thomson Reuters Global Resources. (go-ntlm - https://github.com/m7913d/go-ntlm - BSD-4 License)\n")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conf = config.Load(opts.ConfigFile)
|
||||
|
||||
log.Printf("Starting auth server on %s", opts.SocketAddr)
|
||||
cleanup := func() {
|
||||
if _, err := os.Stat(opts.SocketAddr); err == nil {
|
||||
@@ -100,7 +134,8 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
server := grpc.NewServer()
|
||||
service := NewAuthService(opts.ServiceName)
|
||||
db := database.NewConfig(conf.Users)
|
||||
service := NewAuthService(opts.ServiceName, db)
|
||||
auth.RegisterAuthenticateServer(server, service)
|
||||
server.Serve(listener)
|
||||
}
|
||||
|
||||
42
cmd/auth/config/configuration.go
Normal file
42
cmd/auth/config/configuration.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/knadh/koanf/parsers/yaml"
|
||||
"github.com/knadh/koanf/providers/confmap"
|
||||
"github.com/knadh/koanf/providers/file"
|
||||
"github.com/knadh/koanf/v2"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
type Configuration struct {
|
||||
Users []UserConfig `koanf:"users"`
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
Username string `koanf:"username"`
|
||||
Password string `koanf:"password"`
|
||||
}
|
||||
|
||||
var Conf Configuration
|
||||
|
||||
func Load(configFile string) Configuration {
|
||||
|
||||
var k = koanf.New(".")
|
||||
|
||||
k.Load(confmap.Provider(map[string]interface{}{}, "."), nil)
|
||||
|
||||
if _, err := os.Stat(configFile); os.IsNotExist(err) {
|
||||
log.Printf("Config file %s not found, skipping config file", configFile)
|
||||
} else {
|
||||
if err := k.Load(file.Provider(configFile), yaml.Parser()); err != nil {
|
||||
log.Fatalf("Error loading config from file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
koanfTag := koanf.UnmarshalConf{Tag: "koanf"}
|
||||
k.UnmarshalWithConf("Users", &Conf.Users, koanfTag)
|
||||
|
||||
return Conf
|
||||
|
||||
}
|
||||
25
cmd/auth/database/config.go
Executable file
25
cmd/auth/database/config.go
Executable file
@@ -0,0 +1,25 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/config"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
users map[string]config.UserConfig
|
||||
}
|
||||
|
||||
func NewConfig(users []config.UserConfig) *Config {
|
||||
usersMap := map[string]config.UserConfig{}
|
||||
|
||||
for _, user := range users {
|
||||
usersMap[user.Username] = user
|
||||
}
|
||||
|
||||
return &Config{
|
||||
users: usersMap,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) GetPassword (username string) string {
|
||||
return c.users[username].Password
|
||||
}
|
||||
43
cmd/auth/database/config_test.go
Normal file
43
cmd/auth/database/config_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/config"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func createTestDatabase () (Database) {
|
||||
var users = []config.UserConfig{}
|
||||
|
||||
user1 := config.UserConfig{}
|
||||
user1.Username = "my_username"
|
||||
user1.Password = "my_password"
|
||||
users = append(users, user1)
|
||||
|
||||
user2 := config.UserConfig{}
|
||||
user2.Username = "my_username2"
|
||||
user2.Password = "my_password2"
|
||||
users = append(users, user2)
|
||||
|
||||
config := NewConfig(users)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func TestDatabaseConfigValidUsername(t *testing.T) {
|
||||
database := createTestDatabase()
|
||||
|
||||
if database.GetPassword("my_username") != "my_password" {
|
||||
t.Fatalf("Wrong password returned")
|
||||
}
|
||||
if database.GetPassword("my_username2") != "my_password2" {
|
||||
t.Fatalf("Wrong password returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseInvalidUsername(t *testing.T) {
|
||||
database := createTestDatabase()
|
||||
|
||||
if database.GetPassword("my_invalid_username") != "" {
|
||||
t.Fatalf("Non empty password returned for invalid username")
|
||||
}
|
||||
}
|
||||
5
cmd/auth/database/database.go
Executable file
5
cmd/auth/database/database.go
Executable file
@@ -0,0 +1,5 @@
|
||||
package database
|
||||
|
||||
type Database interface {
|
||||
GetPassword (username string) string
|
||||
}
|
||||
160
cmd/auth/ntlm/ntlm.go
Normal file
160
cmd/auth/ntlm/ntlm.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package ntlm
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/database"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/m7913d/go-ntlm/ntlm"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheExpiration = time.Minute
|
||||
cleanupInterval = time.Minute * 5
|
||||
)
|
||||
|
||||
type NTLMAuth struct {
|
||||
contextCache *cache.Cache
|
||||
|
||||
// Information about the server, returned to the client during authentication
|
||||
ServerName string // e.g. EXAMPLE1
|
||||
DomainName string // e.g. EXAMPLE
|
||||
DnsServerName string // e.g. example1.example.com
|
||||
DnsDomainName string // e.g. example.com
|
||||
DnsTreeName string // e.g. example.com
|
||||
|
||||
Database database.Database
|
||||
}
|
||||
|
||||
func NewNTLMAuth (database database.Database) (*NTLMAuth) {
|
||||
return &NTLMAuth{
|
||||
contextCache: cache.New(cacheExpiration, cleanupInterval),
|
||||
Database: database,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *NTLMAuth) Authenticate(message *auth.NtlmRequest) (*auth.NtlmResponse, error) {
|
||||
r := &auth.NtlmResponse{}
|
||||
r.Authenticated = false
|
||||
|
||||
if message.Session == "" {
|
||||
return r, errors.New("Invalid (empty) session specified")
|
||||
}
|
||||
|
||||
if message.NtlmMessage == "" {
|
||||
return r, errors.New("Empty NTLM message specified")
|
||||
}
|
||||
|
||||
c := h.getContext(message.Session)
|
||||
err := c.Authenticate(message.NtlmMessage, r)
|
||||
|
||||
if err != nil || r.Authenticated {
|
||||
h.removeContext(message.Session)
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (h *NTLMAuth) getContext (session string) (*ntlmContext) {
|
||||
if c_, found := h.contextCache.Get(session); found {
|
||||
if c, ok := c_.(*ntlmContext); ok {
|
||||
return c
|
||||
}
|
||||
}
|
||||
c := new(ntlmContext)
|
||||
c.h = h
|
||||
h.contextCache.Set(session, c, cache.DefaultExpiration)
|
||||
return c
|
||||
}
|
||||
|
||||
func (h *NTLMAuth) removeContext (session string) {
|
||||
h.contextCache.Delete(session)
|
||||
}
|
||||
|
||||
type ntlmContext struct {
|
||||
session ntlm.ServerSession
|
||||
h *NTLMAuth
|
||||
}
|
||||
|
||||
func (c *ntlmContext) Authenticate(authorisationEncoded string, r *auth.NtlmResponse) (error) {
|
||||
authorisation, err := base64.StdEncoding.DecodeString(authorisationEncoded)
|
||||
if err != nil {
|
||||
return errors.New(fmt.Sprintf("Failed to decode NTLM Authorisation header: %s", err))
|
||||
}
|
||||
|
||||
nm, err := ntlm.ParseNegotiateMessage(authorisation)
|
||||
if err == nil {
|
||||
return c.negotiate(nm, r)
|
||||
}
|
||||
if (nm != nil && nm.MessageType == 1) {
|
||||
return errors.New(fmt.Sprintf("Failed to parse NTLM Authorisation header: %s", err))
|
||||
} else if c.session == nil {
|
||||
return errors.New(fmt.Sprintf("New NTLM auth sequence should start with negotioate request"))
|
||||
}
|
||||
|
||||
am, err := ntlm.ParseAuthenticateMessage(authorisation, 2)
|
||||
if err == nil {
|
||||
return c.authenticate(am, r)
|
||||
}
|
||||
|
||||
return errors.New(fmt.Sprintf("Failed to parse NTLM Authorisation header: %s", err))
|
||||
}
|
||||
|
||||
func (c *ntlmContext) negotiate(nm *ntlm.NegotiateMessage, r *auth.NtlmResponse) (error) {
|
||||
session, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode)
|
||||
|
||||
if err != nil {
|
||||
c.session = nil;
|
||||
return errors.New(fmt.Sprintf("Failed to create NTLM server session: %s", err))
|
||||
}
|
||||
|
||||
c.session = session
|
||||
c.session.SetRequireNtHash(true)
|
||||
c.session.SetDomainName(c.h.DomainName)
|
||||
c.session.SetComputerName(c.h.ServerName)
|
||||
c.session.SetDnsDomainName(c.h.DnsDomainName)
|
||||
c.session.SetDnsComputerName(c.h.DnsServerName)
|
||||
c.session.SetDnsTreeName(c.h.DnsTreeName)
|
||||
|
||||
err = c.session.ProcessNegotiateMessage(nm)
|
||||
if err != nil {
|
||||
return errors.New(fmt.Sprintf("Failed to process NTLM negotiate message: %s", err))
|
||||
}
|
||||
|
||||
cm, err := c.session.GenerateChallengeMessage()
|
||||
if err != nil {
|
||||
return errors.New(fmt.Sprintf("Failed to generate NTLM challenge message: %s", err))
|
||||
}
|
||||
|
||||
r.NtlmMessage = base64.StdEncoding.EncodeToString(cm.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ntlmContext) authenticate(am *ntlm.AuthenticateMessage, r *auth.NtlmResponse) (error) {
|
||||
if c.session == nil {
|
||||
return errors.New(fmt.Sprintf("NTLM Authenticate requires active session: first call negotioate"))
|
||||
}
|
||||
|
||||
username := am.UserName.String()
|
||||
password := c.h.Database.GetPassword (username)
|
||||
if password == "" {
|
||||
log.Printf("NTLM: unknown username specified: %s", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.session.SetUserInfo(username,password,"")
|
||||
|
||||
err := c.session.ProcessAuthenticateMessage(am)
|
||||
if err != nil {
|
||||
log.Printf("Failed to process NTLM authenticate message: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
r.Authenticated = true
|
||||
r.Username = username
|
||||
return nil
|
||||
}
|
||||
168
cmd/auth/ntlm/ntlm_test.go
Normal file
168
cmd/auth/ntlm/ntlm_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package ntlm
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/config"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/auth/database"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"github.com/m7913d/go-ntlm/ntlm"
|
||||
"testing"
|
||||
"log"
|
||||
)
|
||||
|
||||
func createTestDatabase () (database.Database) {
|
||||
user := config.UserConfig{}
|
||||
user.Username = "my_username"
|
||||
user.Password = "my_password"
|
||||
|
||||
var users = []config.UserConfig{}
|
||||
users = append(users, user)
|
||||
|
||||
config := database.NewConfig(users)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func TestNtlmValidCredentials(t *testing.T) {
|
||||
client := ntlm.V2ClientSession{}
|
||||
client.SetUserInfo("my_username", "my_password", "")
|
||||
|
||||
authenticateResponse := authenticate(t, &client)
|
||||
if !authenticateResponse.Authenticated {
|
||||
t.Errorf("Failed to authenticate")
|
||||
return
|
||||
}
|
||||
if authenticateResponse.Username != "my_username" {
|
||||
t.Errorf("Wrong username returned")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtlmInvalidPassword(t *testing.T) {
|
||||
client := ntlm.V2ClientSession{}
|
||||
client.SetUserInfo("my_username", "my_invalid_password", "")
|
||||
|
||||
authenticateResponse := authenticate(t, &client)
|
||||
if authenticateResponse.Authenticated {
|
||||
t.Errorf("Authenticated with wrong password")
|
||||
return
|
||||
}
|
||||
if authenticateResponse.Username != "" {
|
||||
t.Errorf("If authentication failed, no username should be returned")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtlmInvalidUsername(t *testing.T) {
|
||||
client := ntlm.V2ClientSession{}
|
||||
client.SetUserInfo("my_invalid_username", "my_password", "")
|
||||
|
||||
authenticateResponse := authenticate(t, &client)
|
||||
if authenticateResponse.Authenticated {
|
||||
t.Errorf("Authenticated with wrong password")
|
||||
return
|
||||
}
|
||||
if authenticateResponse.Username != "" {
|
||||
t.Errorf("If authentication failed, no username should be returned")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func authenticate(t *testing.T, client *ntlm.V2ClientSession) (*auth.NtlmResponse) {
|
||||
session := "X"
|
||||
database := createTestDatabase()
|
||||
|
||||
server := NewNTLMAuth(database)
|
||||
|
||||
negotiate, err := client.GenerateNegotiateMessage()
|
||||
if err != nil {
|
||||
t.Errorf("Could not generate negotiate message: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
negotiateRequest := &auth.NtlmRequest{}
|
||||
negotiateRequest.Session = session
|
||||
negotiateRequest.NtlmMessage = base64.StdEncoding.EncodeToString(negotiate.Bytes())
|
||||
negotiateResponse, err := server.Authenticate(negotiateRequest)
|
||||
if err != nil {
|
||||
t.Errorf("Could not generate challenge message: %s", err)
|
||||
return nil
|
||||
}
|
||||
if negotiateResponse.Authenticated {
|
||||
t.Errorf("User should not be authenticated by after negotiate message")
|
||||
return nil
|
||||
}
|
||||
if negotiateResponse.NtlmMessage == "" {
|
||||
t.Errorf("Could not generate challenge message")
|
||||
return nil
|
||||
}
|
||||
|
||||
decodedChallenge, err := base64.StdEncoding.DecodeString(negotiateResponse.NtlmMessage)
|
||||
if err != nil {
|
||||
t.Errorf("Challenge should be base64 encoded: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
challenge, err := ntlm.ParseChallengeMessage(decodedChallenge)
|
||||
if err != nil {
|
||||
t.Errorf("Invalid challenge message generated: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
client.ProcessChallengeMessage(challenge)
|
||||
authenticate, err := client.GenerateAuthenticateMessage()
|
||||
if err != nil {
|
||||
t.Errorf("Could not generate authenticate message: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
authenticateRequest := &auth.NtlmRequest{}
|
||||
authenticateRequest.Session = session
|
||||
authenticateRequest.NtlmMessage = base64.StdEncoding.EncodeToString(authenticate.Bytes())
|
||||
authenticateResponse, err := server.Authenticate(authenticateRequest)
|
||||
if err != nil {
|
||||
t.Errorf("Could not parse authenticate message: %s", err)
|
||||
return authenticateResponse
|
||||
}
|
||||
if authenticateResponse.NtlmMessage != "" {
|
||||
t.Errorf("Authenticate request should not generate a new NTLM message")
|
||||
return authenticateResponse
|
||||
}
|
||||
return authenticateResponse
|
||||
}
|
||||
|
||||
func TestInvalidBase64 (t *testing.T) {
|
||||
testInvalidDataBase(t, "X", "X") // not valid base64
|
||||
}
|
||||
|
||||
func TestInvalidData (t *testing.T) {
|
||||
testInvalidDataBase(t, "X", "XXXX") // valid base64
|
||||
}
|
||||
|
||||
func TestInvalidDataEmptyMessage (t *testing.T) {
|
||||
testInvalidDataBase(t, "X", "")
|
||||
}
|
||||
|
||||
func TestEmptySession (t *testing.T) {
|
||||
testInvalidDataBase(t, "", "XXXX")
|
||||
}
|
||||
|
||||
func testInvalidDataBase (t *testing.T, session string, data string) {
|
||||
database := createTestDatabase()
|
||||
server := NewNTLMAuth(database)
|
||||
|
||||
request := &auth.NtlmRequest{}
|
||||
request.Session = session
|
||||
request.NtlmMessage = data
|
||||
response, err := server.Authenticate(request)
|
||||
log.Printf("%s",err)
|
||||
if err == nil {
|
||||
t.Errorf("Invalid request should return an error")
|
||||
}
|
||||
if response.Authenticated {
|
||||
t.Errorf("User should not be authenticated using invalid data")
|
||||
}
|
||||
if response.NtlmMessage != "" {
|
||||
t.Errorf("No NTLM message should be generated for invalid data")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user