mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Unix: handle encrypted messages
This commit is contained in:
100
common.go
100
common.go
@@ -12,9 +12,12 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgData struct {
|
type WgData struct {
|
||||||
@@ -34,9 +37,25 @@ type TargetData struct {
|
|||||||
Targets []string `json:"targets"`
|
Targets []string `json:"targets"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HolePunchMessage struct {
|
||||||
|
NewtID string `json:"newtId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchData struct {
|
||||||
|
ServerPubKey string `json:"serverPubKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncryptedHolePunchMessage struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
stopHolepunch chan struct{}
|
stopHolepunch chan struct{}
|
||||||
stopRegister chan struct{}
|
stopRegister chan struct{}
|
||||||
|
olmToken string
|
||||||
|
gerbilServerPubKey string
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -155,8 +174,12 @@ func resolveDomain(domain string) (string, error) {
|
|||||||
return ipAddr, nil
|
return ipAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: we need to send the token with this probably to verify auth
|
|
||||||
func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error {
|
func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error {
|
||||||
|
|
||||||
|
if gerbilServerPubKey == "" || olmToken == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Bind to specific local port
|
// Bind to specific local port
|
||||||
localAddr := &net.UDPAddr{
|
localAddr := &net.UDPAddr{
|
||||||
Port: int(sourcePort),
|
Port: int(sourcePort),
|
||||||
@@ -176,16 +199,30 @@ func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error
|
|||||||
|
|
||||||
payload := struct {
|
payload := struct {
|
||||||
OlmID string `json:"olmId"`
|
OlmID string `json:"olmId"`
|
||||||
|
Token string `json:"token"`
|
||||||
}{
|
}{
|
||||||
OlmID: olmID,
|
OlmID: olmID,
|
||||||
|
Token: olmToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(payload)
|
// Convert payload to JSON
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.WriteToUDP(data, remoteAddr)
|
// Encrypt the payload using the server's WireGuard public key
|
||||||
|
encryptedPayload, err := encryptPayload(payloadBytes, gerbilServerPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(encryptedPayload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.WriteToUDP(jsonData, remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||||
}
|
}
|
||||||
@@ -193,6 +230,59 @@ func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||||
|
// Generate an ephemeral keypair for this message
|
||||||
|
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||||
|
}
|
||||||
|
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||||
|
|
||||||
|
// Parse the server's public key
|
||||||
|
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
||||||
|
var ephPrivKeyFixed [32]byte
|
||||||
|
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||||
|
|
||||||
|
// Perform X25519 key exchange
|
||||||
|
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an AEAD cipher using the shared secret
|
||||||
|
aead, err := chacha20poly1305.New(sharedSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random nonce
|
||||||
|
nonce := make([]byte, aead.NonceSize())
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload
|
||||||
|
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||||
|
|
||||||
|
// Prepare the final encrypted message
|
||||||
|
encryptedMsg := struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}{
|
||||||
|
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||||
|
Nonce: nonce,
|
||||||
|
Ciphertext: ciphertext,
|
||||||
|
}
|
||||||
|
|
||||||
|
return encryptedMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) {
|
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) {
|
||||||
host, err := resolveDomain(endpoint)
|
host, err := resolveDomain(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
25
main.go
25
main.go
@@ -214,6 +214,7 @@ func main() {
|
|||||||
// Create TUN device and network stack
|
// Create TUN device and network stack
|
||||||
var dev *device.Device
|
var dev *device.Device
|
||||||
var wgData WgData
|
var wgData WgData
|
||||||
|
var holePunchData HolePunchData
|
||||||
var uapi *os.File
|
var uapi *os.File
|
||||||
var tdev tun.Device
|
var tdev tun.Device
|
||||||
|
|
||||||
@@ -426,6 +427,23 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
logger.Info("WireGuard device created.")
|
logger.Info("WireGuard device created.")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gerbilServerPubKey = holePunchData.ServerPubKey
|
||||||
|
})
|
||||||
|
|
||||||
olm.OnConnect(func() error {
|
olm.OnConnect(func() error {
|
||||||
publicKey := privateKey.PublicKey()
|
publicKey := privateKey.PublicKey()
|
||||||
logger.Debug("Public key: %s", publicKey)
|
logger.Debug("Public key: %s", publicKey)
|
||||||
@@ -436,8 +454,9 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
// start sending UDP hole punch
|
olm.OnTokenUpdate(func(token string) {
|
||||||
go keepSendingUDPHolePunch(endpoint, id, sourcePort)
|
olmToken = token
|
||||||
|
})
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
// Connect to the WebSocket server
|
||||||
if err := olm.Connect(); err != nil {
|
if err := olm.Connect(); err != nil {
|
||||||
@@ -445,6 +464,8 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
}
|
}
|
||||||
defer olm.Close()
|
defer olm.Close()
|
||||||
|
|
||||||
|
go keepSendingUDPHolePunch(endpoint, id, sourcePort)
|
||||||
|
|
||||||
// Wait for interrupt signal
|
// Wait for interrupt signal
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ type Client struct {
|
|||||||
isConnected bool
|
isConnected bool
|
||||||
reconnectMux sync.RWMutex
|
reconnectMux sync.RWMutex
|
||||||
|
|
||||||
onConnect func() error
|
onConnect func() error
|
||||||
|
onTokenUpdate func(token string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
@@ -45,6 +46,10 @@ func (c *Client) OnConnect(callback func() error) {
|
|||||||
c.onConnect = callback
|
c.onConnect = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||||
|
c.onTokenUpdate = callback
|
||||||
|
}
|
||||||
|
|
||||||
// NewClient creates a new Olm client
|
// NewClient creates a new Olm client
|
||||||
func NewClient(olmID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
func NewClient(olmID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
@@ -266,6 +271,8 @@ func (c *Client) establishConnection() error {
|
|||||||
return fmt.Errorf("failed to get token: %w", err)
|
return fmt.Errorf("failed to get token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.onTokenUpdate(token)
|
||||||
|
|
||||||
// Parse the base URL to determine protocol and hostname
|
// Parse the base URL to determine protocol and hostname
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ type Config struct {
|
|||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
Data struct {
|
Data struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
ServerPubKey string `json:"serverPubKey"`
|
||||||
} `json:"data"`
|
} `json:"data"`
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
|||||||
Reference in New Issue
Block a user