From 3d70ff190f8d54bad64fef265d261e9aa712fa92 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 15 Mar 2025 21:46:54 -0400 Subject: [PATCH] Unix: handle encrypted messages --- common.go | 100 +++++++++++++++++++++++++++++++++++++++++--- main.go | 25 ++++++++++- websocket/client.go | 9 +++- websocket/types.go | 3 +- 4 files changed, 128 insertions(+), 9 deletions(-) diff --git a/common.go b/common.go index bd31ea5..f1c4efd 100644 --- a/common.go +++ b/common.go @@ -12,9 +12,12 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/olm/websocket" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" "golang.org/x/exp/rand" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) type WgData struct { @@ -34,9 +37,25 @@ type TargetData struct { Targets []string `json:"targets"` } +type HolePunchMessage struct { + NewtID string `json:"newtId"` +} + +type HolePunchData struct { + ServerPubKey string `json:"serverPubKey"` +} + +type EncryptedHolePunchMessage struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` +} + var ( - stopHolepunch chan struct{} - stopRegister chan struct{} + stopHolepunch chan struct{} + stopRegister chan struct{} + olmToken string + gerbilServerPubKey string ) const ( @@ -155,8 +174,12 @@ func resolveDomain(domain string) (string, error) { return ipAddr, nil } -// TODO: we need to send the token with this probably to verify auth func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error { + + if gerbilServerPubKey == "" || olmToken == "" { + return nil + } + // Bind to specific local port localAddr := &net.UDPAddr{ Port: int(sourcePort), @@ -176,16 +199,30 @@ func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error payload := struct { OlmID string `json:"olmId"` + Token string `json:"token"` }{ OlmID: olmID, + Token: olmToken, } - data, err := json.Marshal(payload) + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to marshal payload: %v", err) } - _, err = conn.WriteToUDP(data, remoteAddr) + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := encryptPayload(payloadBytes, gerbilServerPubKey) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %v", err) + } + + jsonData, err := json.Marshal(encryptedPayload) + if err != nil { + return fmt.Errorf("failed to marshal encrypted payload: %v", err) + } + + _, err = conn.WriteToUDP(jsonData, remoteAddr) if err != nil { return fmt.Errorf("failed to send UDP packet: %v", err) } @@ -193,6 +230,59 @@ func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error return nil } +func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { + // Generate an ephemeral keypair for this message + ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) + } + ephemeralPublicKey := ephemeralPrivateKey.PublicKey() + + // Parse the server's public key + serverPubKey, err := wgtypes.ParseKey(serverPublicKey) + if err != nil { + return nil, fmt.Errorf("failed to parse server public key: %v", err) + } + + // Use X25519 for key exchange (replacing deprecated ScalarMult) + var ephPrivKeyFixed [32]byte + copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) + + // Perform X25519 key exchange + sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create an AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %v", err) + } + + // Encrypt the payload + ciphertext := aead.Seal(nil, nonce, payload, nil) + + // Prepare the final encrypted message + encryptedMsg := struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + }{ + EphemeralPublicKey: ephemeralPublicKey.String(), + Nonce: nonce, + Ciphertext: ciphertext, + } + + return encryptedMsg, nil +} + func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) { host, err := resolveDomain(endpoint) if err != nil { diff --git a/main.go b/main.go index 1789106..930025c 100644 --- a/main.go +++ b/main.go @@ -214,6 +214,7 @@ func main() { // Create TUN device and network stack var dev *device.Device var wgData WgData + var holePunchData HolePunchData var uapi *os.File var tdev tun.Device @@ -426,6 +427,23 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Info("WireGuard device created.") }) + olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &holePunchData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + gerbilServerPubKey = holePunchData.ServerPubKey + }) + olm.OnConnect(func() error { publicKey := privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) @@ -436,8 +454,9 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub return nil }) - // start sending UDP hole punch - go keepSendingUDPHolePunch(endpoint, id, sourcePort) + olm.OnTokenUpdate(func(token string) { + olmToken = token + }) // Connect to the WebSocket server if err := olm.Connect(); err != nil { @@ -445,6 +464,8 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub } defer olm.Close() + go keepSendingUDPHolePunch(endpoint, id, sourcePort) + // Wait for interrupt signal sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) diff --git a/websocket/client.go b/websocket/client.go index 7e13606..9725c50 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -27,7 +27,8 @@ type Client struct { isConnected bool reconnectMux sync.RWMutex - onConnect func() error + onConnect func() error + onTokenUpdate func(token string) } type ClientOption func(*Client) @@ -45,6 +46,10 @@ func (c *Client) OnConnect(callback func() error) { c.onConnect = callback } +func (c *Client) OnTokenUpdate(callback func(token string)) { + c.onTokenUpdate = callback +} + // NewClient creates a new Olm client func NewClient(olmID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { config := &Config{ @@ -266,6 +271,8 @@ func (c *Client) establishConnection() error { return fmt.Errorf("failed to get token: %w", err) } + c.onTokenUpdate(token) + // Parse the base URL to determine protocol and hostname baseURL, err := url.Parse(c.baseURL) if err != nil { diff --git a/websocket/types.go b/websocket/types.go index 7786745..e93c9f9 100644 --- a/websocket/types.go +++ b/websocket/types.go @@ -9,7 +9,8 @@ type Config struct { type TokenResponse struct { Data struct { - Token string `json:"token"` + Token string `json:"token"` + ServerPubKey string `json:"serverPubKey"` } `json:"data"` Success bool `json:"success"` Message string `json:"message"`