diff --git a/relay/relay.go b/relay/relay.go index 396c196..8611241 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -12,17 +12,27 @@ import ( "time" "github.com/fosrl/gerbil/logger" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +type EncryptedHolePunchMessage struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` +} + type HolePunchMessage struct { OlmID string `json:"olmId"` NewtID string `json:"newtId"` + Token string `json:"token"` } type ClientEndpoint struct { OlmID string `json:"olmId"` NewtID string `json:"newtId"` + Token string `json:"token"` IP string `json:"ip"` Port int `json:"port"` Timestamp int64 `json:"timestamp"` @@ -58,23 +68,23 @@ var bufferPool = sync.Pool{ }, } -// UDPProxyServer now has a channel for incoming packets. +// UDPProxyServer has a channel for incoming packets. type UDPProxyServer struct { addr string serverURL string conn *net.UDPConn proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port" connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" - publicKey wgtypes.Key + privateKey wgtypes.Key packetChan chan Packet } // NewUDPProxyServer initializes the server with a buffered packet channel. -func NewUDPProxyServer(addr, serverURL string, publicKey wgtypes.Key) *UDPProxyServer { +func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key) *UDPProxyServer { return &UDPProxyServer{ addr: addr, serverURL: serverURL, - publicKey: publicKey, + privateKey: privateKey, packetChan: make(chan Packet, 1000), } } @@ -137,32 +147,91 @@ func (s *UDPProxyServer) packetWorker() { // Process as a WireGuard packet. s.handleWireGuardPacket(packet.data, packet.remoteAddr) } else { - // Process as a hole punch message. - var msg HolePunchMessage - if err := json.Unmarshal(packet.data, &msg); err != nil { - logger.Error("Error unmarshaling message: %v", err) - } else { - endpoint := ClientEndpoint{ - OlmID: msg.OlmID, - NewtID: msg.NewtID, - IP: packet.remoteAddr.IP.String(), - Port: packet.remoteAddr.Port, - Timestamp: time.Now().Unix(), - } - // You can call notifyServer synchronously here or dispatch further if needed. - s.notifyServer(endpoint) + // Process as an encrypted hole punch message + var encMsg EncryptedHolePunchMessage + if err := json.Unmarshal(packet.data, &encMsg); err != nil { + logger.Error("Error unmarshaling encrypted message: %v", err) + // Return the buffer to the pool for reuse and continue with next packet + bufferPool.Put(packet.data[:1500]) + continue } + + if encMsg.EphemeralPublicKey == "" { + logger.Error("Received malformed message without ephemeral key") + // Return the buffer to the pool for reuse and continue with next packet + bufferPool.Put(packet.data[:1500]) + continue + } + + // This appears to be an encrypted message + decryptedData, err := s.decryptMessage(encMsg) + if err != nil { + logger.Error("Failed to decrypt message: %v", err) + // Return the buffer to the pool for reuse and continue with next packet + bufferPool.Put(packet.data[:1500]) + continue + } + + // Process the decrypted hole punch message + var msg HolePunchMessage + if err := json.Unmarshal(decryptedData, &msg); err != nil { + logger.Error("Error unmarshaling decrypted message: %v", err) + // Return the buffer to the pool for reuse and continue with next packet + bufferPool.Put(packet.data[:1500]) + continue + } + + endpoint := ClientEndpoint{ + NewtID: msg.NewtID, + OlmID: msg.OlmID, + Token: msg.Token, + IP: packet.remoteAddr.IP.String(), + Port: packet.remoteAddr.Port, + Timestamp: time.Now().Unix(), + } + s.notifyServer(endpoint) } // Return the buffer to the pool for reuse. bufferPool.Put(packet.data[:1500]) } } -// --- The remaining methods remain largely the same --- -// For example: fetchInitialMappings, handleWireGuardPacket, getOrCreateConnection, etc. +// decryptMessage decrypts the message using the server's private key +func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byte, error) { + // Parse the ephemeral public key + ephPubKey, err := wgtypes.ParseKey(encMsg.EphemeralPublicKey) + if err != nil { + return nil, fmt.Errorf("failed to parse ephemeral public key: %v", err) + } + + // Use X25519 for key exchange instead of ScalarMult + sharedSecret, err := curve25519.X25519(s.privateKey[:], ephPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create the AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Verify nonce size + if len(encMsg.Nonce) != aead.NonceSize() { + return nil, fmt.Errorf("invalid nonce size") + } + + // Decrypt the ciphertext + plaintext, err := aead.Open(nil, encMsg.Nonce, encMsg.Ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt message: %v", err) + } + + return plaintext, nil +} func (s *UDPProxyServer) fetchInitialMappings() error { - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.publicKey.PublicKey().String()))) + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String()))) resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body) if err != nil { return fmt.Errorf("failed to fetch mappings: %v", err) @@ -305,6 +374,8 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { return } + logger.Debug("Received proxy mapping: %v", mapping) + // Store the mapping key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port) s.proxyMappings.Store(key, mapping)