feat/mtls-support

This commit is contained in:
progressive-kiwi
2025-03-31 00:06:40 +02:00
parent 2ff8df9a8d
commit 9b3c82648b
5 changed files with 127 additions and 14 deletions

View File

@@ -2,16 +2,19 @@ package websocket
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"software.sslmate.com/src/go-pkcs12"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/gorilla/websocket"
)
@@ -22,6 +25,7 @@ type Client struct {
handlers map[string]MessageHandler
done chan struct{}
handlersMux sync.RWMutex
tlsConfig *tls.Config
reconnectInterval time.Duration
isConnected bool
@@ -41,6 +45,12 @@ func WithBaseURL(url string) ClientOption {
}
}
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
return func(c *Client) {
c.tlsConfig = tlsConfig
}
}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
@@ -177,6 +187,12 @@ func (c *Client) getToken() (string, error) {
// Make the request
client := &http.Client{}
if c.tlsConfig != nil {
logger.Info("Adding tls to req")
client.Transport = &http.Transport{
TLSClientConfig: c.tlsConfig,
}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to check token validity: %w", err)
@@ -220,6 +236,11 @@ func (c *Client) getToken() (string, error) {
// Make the request
client := &http.Client{}
if c.tlsConfig != nil {
client.Transport = &http.Transport{
TLSClientConfig: c.tlsConfig,
}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err)
@@ -295,7 +316,11 @@ func (c *Client) establishConnection() error {
u.RawQuery = q.Encode()
// Connect to WebSocket
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
dialer := websocket.DefaultDialer
if c.tlsConfig != nil {
dialer.TLSClientConfig = c.tlsConfig
}
conn, _, err := dialer.Dial(u.String(), nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
@@ -353,3 +378,41 @@ func (c *Client) setConnected(status bool) {
defer c.reconnectMux.Unlock()
c.isConnected = status
}
// LoadClientCertificate Helper method to load client certificates
func LoadClientCertificate(p12Path string) (*tls.Config, error) {
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
}
// Parse PKCS12 with empty password for non-encrypted files
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
if err != nil {
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
}
// Create certificate
cert := tls.Certificate{
Certificate: [][]byte{certificate.Raw},
PrivateKey: privateKey,
}
// Optional: Add CA certificates if present
rootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
}
if len(caCerts) > 0 {
for _, caCert := range caCerts {
rootCAs.AddCert(caCert)
}
}
// Create TLS configuration
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: rootCAs,
}, nil
}