mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
389 lines
10 KiB
Go
389 lines
10 KiB
Go
package inspect
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/sha256"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
|
|
"golang.org/x/crypto/hkdf"
|
|
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
)
|
|
|
|
// QUIC version numbers as carried in the long header Version field.
const (
	quicV1Version uint32 = 0x00000001 // RFC 9000
	quicV2Version uint32 = 0x6b3343cf // RFC 9369
)

// Version-specific initial salts. Each is passed to HKDF-Extract together with
// the client's Destination Connection ID to derive Initial packet keys.
var (
	// quicV1Salt is the initial salt for QUIC v1 (RFC 9001 Section 5.2).
	quicV1Salt = []byte{
		0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17,
		0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
	}

	// quicV2Salt is the initial salt for QUIC v2 (RFC 9369).
	quicV2Salt = []byte{
		0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93,
		0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9,
	}
)
|
|
|
|
// ExtractQUICSNI extracts the SNI from a QUIC Initial packet.
// The Initial packet's encryption uses well-known keys derived from the
// Destination Connection ID, so any observer can decrypt it (by design).
//
// data must begin with a client Initial packet using QUIC v1 (RFC 9000/9001)
// or QUIC v2 (RFC 9369). Packets coalesced after it in the same datagram are
// ignored. Only the CRYPTO data carried in this single packet is examined, so
// a ClientHello spread over several Initial packets may be incomplete.
//
// All length fields are validated against the input before use, so malformed
// or hostile packets produce an error instead of a panic.
func ExtractQUICSNI(data []byte) (domain.Domain, error) {
	// Connection IDs are limited to 20 bytes in QUIC v1 and v2 (RFC 9000 Section 17.2).
	const maxConnIDLen = 20
	// Header protection samples 16 bytes starting 4 bytes past the packet
	// number offset (RFC 9001 Section 5.4.2).
	const (
		hpSampleOffset = 4
		hpSampleLen    = 16
	)

	if len(data) < 5 {
		return "", fmt.Errorf("packet too short")
	}

	// Check for QUIC Long Header (form bit set)
	if data[0]&0x80 == 0 {
		return "", fmt.Errorf("not a QUIC long header packet")
	}

	version := binary.BigEndian.Uint32(data[1:5])

	var salt []byte
	var initialLabel, keyLabel, ivLabel, hpLabel string
	var initialType byte // long header packet type value that denotes Initial

	switch version {
	case quicV1Version:
		salt = quicV1Salt
		initialLabel, keyLabel, ivLabel, hpLabel = "client in", "quic key", "quic iv", "quic hp"
		initialType = 0b00 // RFC 9000 Section 17.2
	case quicV2Version:
		salt = quicV2Salt
		initialLabel, keyLabel, ivLabel, hpLabel = "client in", "quicv2 key", "quicv2 iv", "quicv2 hp"
		initialType = 0b01 // RFC 9369 Section 3.2
	default:
		return "", fmt.Errorf("unsupported QUIC version: 0x%08x", version)
	}

	// The packet type bits (0x30) are not covered by header protection.
	// 0-RTT, Handshake and Retry packets cannot be opened with Initial keys,
	// and Retry packets carry no Length field. Reject them before any parsing
	// or crypto work.
	if packetType := (data[0] >> 4) & 0x03; packetType != initialType {
		return "", fmt.Errorf("not a QUIC Initial packet (long header type %d)", packetType)
	}

	// Parse Long Header
	if len(data) < 6 {
		return "", fmt.Errorf("packet too short for DCID length")
	}
	dcidLen := int(data[5])
	if dcidLen > maxConnIDLen {
		return "", fmt.Errorf("invalid DCID length %d", dcidLen)
	}
	if len(data) < 6+dcidLen+1 {
		return "", fmt.Errorf("packet too short for DCID")
	}
	dcid := data[6 : 6+dcidLen]

	scidLenOff := 6 + dcidLen
	scidLen := int(data[scidLenOff])
	if scidLen > maxConnIDLen {
		return "", fmt.Errorf("invalid SCID length %d", scidLen)
	}
	tokenLenOff := scidLenOff + 1 + scidLen
	if tokenLenOff >= len(data) {
		return "", fmt.Errorf("packet too short for token length")
	}

	// Token length is a variable-length integer
	tokenLen, tokenLenSize, err := readVarInt(data[tokenLenOff:])
	if err != nil {
		return "", fmt.Errorf("read token length: %w", err)
	}
	// Compare as uint64 before converting. A hostile 62-bit length would
	// wrap int on 32-bit platforms and produce a negative offset. The
	// subtraction is non-negative because readVarInt consumed tokenLenSize bytes.
	if tokenLen >= uint64(len(data)-tokenLenOff-tokenLenSize) {
		return "", fmt.Errorf("packet too short for payload length")
	}
	payloadLenOff := tokenLenOff + tokenLenSize + int(tokenLen)

	// Payload length is a variable-length integer
	payloadLen, payloadLenSize, err := readVarInt(data[payloadLenOff:])
	if err != nil {
		return "", fmt.Errorf("read payload length: %w", err)
	}
	pnOffset := payloadLenOff + payloadLenSize

	// The Length field covers the packet number plus the protected payload.
	// It must stay inside the datagram; bytes beyond it belong to coalesced
	// packets. It must also be long enough for the header protection sample
	// to come from this packet. Together these checks keep every slice below
	// in bounds, including the ciphertext slice.
	if payloadLen > uint64(len(data)-pnOffset) {
		return "", fmt.Errorf("packet truncated: length %d exceeds remaining %d bytes", payloadLen, len(data)-pnOffset)
	}
	if payloadLen < hpSampleOffset+hpSampleLen {
		return "", fmt.Errorf("packet too short for HP sample")
	}
	packetEnd := pnOffset + int(payloadLen)

	// Derive initial keys
	clientKey, clientIV, clientHP, err := deriveInitialKeys(dcid, salt, initialLabel, keyLabel, ivLabel, hpLabel)
	if err != nil {
		return "", fmt.Errorf("derive initial keys: %w", err)
	}

	// Remove header protection: mask = AES-ECB(hp_key, sample).
	sampleOffset := pnOffset + hpSampleOffset
	sample := data[sampleOffset : sampleOffset+hpSampleLen]

	hpBlock, err := aes.NewCipher(clientHP)
	if err != nil {
		return "", fmt.Errorf("create HP cipher: %w", err)
	}
	mask := make([]byte, aes.BlockSize)
	hpBlock.Encrypt(mask, sample)

	// Work on a private copy of just this packet, because unmasking rewrites
	// header bytes and the caller's buffer must stay untouched.
	packet := make([]byte, packetEnd)
	copy(packet, data[:packetEnd])

	// Long header: only the low 4 bits of the first byte are masked. They
	// include the packet number length.
	packet[0] ^= mask[0] & 0x0f
	pnLen := int(packet[0]&0x03) + 1

	// Unmask the packet number and reconstruct it. The truncated value is
	// used directly. This assumes the full packet number fits in pnLen bytes,
	// which holds for a connection's first Initial packets.
	var pn uint32
	for i := 0; i < pnLen; i++ {
		packet[pnOffset+i] ^= mask[1+i]
		pn = pn<<8 | uint32(packet[pnOffset+i])
	}

	// Nonce = IV XOR packet number, right-aligned (RFC 9001 Section 5.3).
	nonce := make([]byte, len(clientIV))
	copy(nonce, clientIV)
	for i := 0; i < 4; i++ {
		nonce[len(nonce)-1-i] ^= byte(pn >> (8 * i))
	}

	// Decrypt payload
	block, err := aes.NewCipher(clientKey)
	if err != nil {
		return "", fmt.Errorf("create AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return "", fmt.Errorf("create AEAD: %w", err)
	}

	// The unprotected header (through the packet number) is the AAD. The
	// rest of the packet is ciphertext plus the 16-byte tag; it is at least
	// 16 bytes because payloadLen >= 20 and pnLen <= 4.
	hdrEnd := pnOffset + pnLen
	plaintext, err := aead.Open(nil, nonce, packet[hdrEnd:], packet[:hdrEnd])
	if err != nil {
		return "", fmt.Errorf("decrypt QUIC payload: %w", err)
	}

	// Parse CRYPTO frames to extract ClientHello
	clientHello, err := extractCryptoFrames(plaintext)
	if err != nil {
		return "", fmt.Errorf("extract CRYPTO frames: %w", err)
	}

	// Check the error before touching info. It may not be usable (e.g. nil)
	// when parsing fails.
	info, err := parseHelloBody(clientHello)
	if err != nil {
		return "", err
	}
	return info.SNI, nil
}
|
|
|
|
// deriveInitialKeys derives the client's Initial packet protection material
// from the Destination Connection ID (RFC 9001 Section 5.2):
//
//	initial_secret = HKDF-Extract(salt, dcid)
//	client_secret  = HKDF-Expand-Label(initial_secret, initialLabel, "", 32)
//	key            = HKDF-Expand-Label(client_secret, keyLabel, "", 16)
//	iv             = HKDF-Expand-Label(client_secret, ivLabel, "", 12)
//	hp             = HKDF-Expand-Label(client_secret, hpLabel, "", 16)
//
// The salt and labels are version specific and supplied by the caller. On any
// failure all three outputs are nil and err describes the failing step.
func deriveInitialKeys(dcid, salt []byte, initialLabel, keyLabel, ivLabel, hpLabel string) (key, iv, hp []byte, err error) {
	clientSecret, err := hkdfExpandLabel(hkdf.Extract(sha256.New, dcid, salt), initialLabel, nil, 32)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("derive client secret: %w", err)
	}

	// Each output is expanded from the client secret with its own label and size.
	outputs := [...]struct {
		dst    *[]byte
		label  string
		size   int
		errFmt string
	}{
		{&key, keyLabel, 16, "derive key: %w"},
		{&iv, ivLabel, 12, "derive IV: %w"},
		{&hp, hpLabel, 16, "derive HP key: %w"},
	}
	for _, o := range outputs {
		material, expandErr := hkdfExpandLabel(clientSecret, o.label, nil, o.size)
		if expandErr != nil {
			return nil, nil, nil, fmt.Errorf(o.errFmt, expandErr)
		}
		*o.dst = material
	}

	return key, iv, hp, nil
}
|
|
|
|
// hkdfExpandLabel implements TLS 1.3 HKDF-Expand-Label (RFC 8446 Section 7.1)
// over SHA-256. The HkdfLabel info structure is serialized as
//
//	uint16 length
//	opaque label<7..255>   = "tls13 " + label
//	opaque context<0..255> = context
//
// and length bytes are then read from HKDF-Expand(secret, info). The label and
// context lengths are written as single bytes without range validation.
func hkdfExpandLabel(secret []byte, label string, context []byte, length int) ([]byte, error) {
	fullLabel := "tls13 " + label

	info := make([]byte, 0, 2+1+len(fullLabel)+1+len(context))
	info = binary.BigEndian.AppendUint16(info, uint16(length))
	info = append(info, byte(len(fullLabel)))
	info = append(info, fullLabel...)
	info = append(info, byte(len(context)))
	info = append(info, context...)

	out := make([]byte, length)
	if _, err := io.ReadFull(hkdf.Expand(sha256.New, secret, info), out); err != nil {
		return nil, err
	}
	return out, nil
}
|
|
|
|
// maxCryptoFrameSize limits total CRYPTO frame data to prevent memory exhaustion.
|
|
const maxCryptoFrameSize = 64 * 1024
|
|
|
|
// extractCryptoFrames reassembles CRYPTO frame data from QUIC frames.
|
|
func extractCryptoFrames(frames []byte) ([]byte, error) {
|
|
var result []byte
|
|
pos := 0
|
|
|
|
for pos < len(frames) {
|
|
frameType := frames[pos]
|
|
|
|
switch {
|
|
case frameType == 0x00:
|
|
// PADDING frame
|
|
pos++
|
|
|
|
case frameType == 0x06:
|
|
// CRYPTO frame
|
|
pos++
|
|
|
|
offset, n, err := readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read crypto offset: %w", err)
|
|
}
|
|
pos += n
|
|
_ = offset // We assume ordered, offset 0 for Initial
|
|
|
|
dataLen, n, err := readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read crypto data length: %w", err)
|
|
}
|
|
pos += n
|
|
|
|
end := pos + int(dataLen)
|
|
if end > len(frames) {
|
|
return nil, fmt.Errorf("CRYPTO frame data truncated")
|
|
}
|
|
|
|
result = append(result, frames[pos:end]...)
|
|
if len(result) > maxCryptoFrameSize {
|
|
return nil, fmt.Errorf("CRYPTO frame data exceeds %d bytes", maxCryptoFrameSize)
|
|
}
|
|
pos = end
|
|
|
|
case frameType == 0x01:
|
|
// PING frame
|
|
pos++
|
|
|
|
case frameType == 0x02 || frameType == 0x03:
|
|
// ACK frame - skip
|
|
pos++
|
|
// Largest Acknowledged
|
|
_, n, err := readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ACK: %w", err)
|
|
}
|
|
pos += n
|
|
// ACK Delay
|
|
_, n, err = readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ACK delay: %w", err)
|
|
}
|
|
pos += n
|
|
// ACK Range Count
|
|
rangeCount, n, err := readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ACK range count: %w", err)
|
|
}
|
|
pos += n
|
|
// First ACK Range
|
|
_, n, err = readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read first ACK range: %w", err)
|
|
}
|
|
pos += n
|
|
// Additional ranges
|
|
for i := uint64(0); i < rangeCount; i++ {
|
|
_, n, err = readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ACK gap: %w", err)
|
|
}
|
|
pos += n
|
|
_, n, err = readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ACK range: %w", err)
|
|
}
|
|
pos += n
|
|
}
|
|
// ECN counts for type 0x03
|
|
if frameType == 0x03 {
|
|
for range 3 {
|
|
_, n, err = readVarInt(frames[pos:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ECN count: %w", err)
|
|
}
|
|
pos += n
|
|
}
|
|
}
|
|
|
|
default:
|
|
// Unknown frame type, stop parsing
|
|
if len(result) > 0 {
|
|
return result, nil
|
|
}
|
|
return nil, fmt.Errorf("unknown QUIC frame type: 0x%02x at offset %d", frameType, pos)
|
|
}
|
|
}
|
|
|
|
if len(result) == 0 {
|
|
return nil, fmt.Errorf("no CRYPTO frames found")
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// readVarInt reads a QUIC variable-length integer.
|
|
// Returns (value, bytes consumed, error).
|
|
func readVarInt(data []byte) (uint64, int, error) {
|
|
if len(data) == 0 {
|
|
return 0, 0, fmt.Errorf("empty data for varint")
|
|
}
|
|
|
|
prefix := data[0] >> 6
|
|
length := 1 << prefix
|
|
|
|
if len(data) < length {
|
|
return 0, 0, fmt.Errorf("varint truncated: need %d, have %d", length, len(data))
|
|
}
|
|
|
|
var val uint64
|
|
switch length {
|
|
case 1:
|
|
val = uint64(data[0] & 0x3f)
|
|
case 2:
|
|
val = uint64(binary.BigEndian.Uint16(data[:2])) & 0x3fff
|
|
case 4:
|
|
val = uint64(binary.BigEndian.Uint32(data[:4])) & 0x3fffffff
|
|
case 8:
|
|
val = binary.BigEndian.Uint64(data[:8]) & 0x3fffffffffffffff
|
|
}
|
|
|
|
return val, length, nil
|
|
}
|