missed files for client side and shared files

This commit is contained in:
crn4
2026-05-12 16:54:56 +02:00
parent 582cd70086
commit 728057ef15
5 changed files with 1472 additions and 0 deletions

View File

@@ -0,0 +1,186 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
// explicit NetbirdSSH protocol, and the legacy implicit case where a
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
// the destination peer has SSHEnabled=true locally. Belt-and-suspenders for
// the B1 fix that the prod-DB equivalence test alone wouldn't have caught
// if no account had this combination.
func TestComputeSSHEnabledForPeer(t *testing.T) {
const targetPeerID = "target"
const targetGroupID = "g_dst"
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
return &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
Groups: map[string]*types.Group{targetGroupID: group},
Policies: []*types.Policy{{
ID: "p",
Enabled: true,
Rules: []*types.PolicyRule{rule},
}},
}, peer
}
cases := []struct {
name string
peerSSH bool
rule types.PolicyRule
wantEnabled bool
}{
{
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-without-peer-ssh-disabled",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "implicit-tcp-22022-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-all-protocol-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-port-range-covers-22",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "tcp-80-no-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "disabled-rule-skipped",
peerSSH: true,
rule: types.PolicyRule{
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "peer-not-in-destinations",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g_other"}, // target not in this group
},
wantEnabled: false,
},
{
name: "peer-typed-destination-resource-matches",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
},
wantEnabled: true,
},
{
name: "non-peer-destination-resource-falls-through-to-groups",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
Destinations: []string{targetGroupID}, // saved by group fallback
},
wantEnabled: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c, peer := mkComponents(&tc.rule, tc.peerSSH)
got := computeSSHEnabledForPeer(c, peer)
assert.Equal(t, tc.wantEnabled, got)
})
}
}
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
// belt-and-suspenders presence guard mirroring Calculate's
// getAllPeersFromGroups invariant.
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
c := &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
Groups: map[string]*types.Group{
"g": {ID: "g", Peers: []string{"missing"}},
},
Policies: []*types.Policy{{
ID: "p", Enabled: true,
Rules: []*types.PolicyRule{{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g"},
}},
}},
}
assert.False(t, computeSSHEnabledForPeer(c, peer),
"missing target peer must short-circuit to false, not consult policies")
}
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
// function entry — Calculate doesn't accept nil either, but the helper is
// exported indirectly via ToComponentSyncResponse and may receive nil
// components on graceful-degrade paths.
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
}

View File

@@ -0,0 +1,586 @@
package networkmap
import (
"encoding/base64"
"fmt"
"net"
"net/netip"
"strconv"
"time"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
// DecodeEnvelope converts a NetworkMapEnvelope into a NetworkMapComponents
// the client can run Calculate() over. Every ID-reference on the wire is a
// uint32 (peer index or account_seq_id) — no xid strings travel. The decoder
// synthesises consistent string IDs from the uint32s so the reconstructed
// components struct round-trips through Calculate exactly the way the
// server-side typed components would.
//
// Synthetic ID scheme (underscore-separated, visually distinct from the xid
// format Calculate would put in log lines under the legacy path):
//
// Peers "p_<wire_index>" // envelope.peers is index-addressed
// Groups "g_<account_seq_id>"
// Policies "pol_<account_seq_id>" // 1 rule per policy
// Routes "r_<account_seq_id>"
// Network resources "nres_<account_seq_id>"
// Posture checks "pc_<account_seq_id>"
// Networks "net_<account_seq_id>"
// Nameserver groups "nsg_<account_seq_id>"
func DecodeEnvelope(env *proto.NetworkMapEnvelope) (*types.NetworkMapComponents, error) {
if env == nil {
return nil, fmt.Errorf("nil envelope")
}
full := env.GetFull()
if full == nil {
return nil, fmt.Errorf("envelope has no Full payload")
}
c := &types.NetworkMapComponents{
PeerID: "", // engine fills its own peer id from PeerConfig
Network: decodeAccountNetwork(full.Network),
AccountSettings: decodeAccountSettings(full.AccountSettings),
CustomZoneDomain: full.CustomZoneDomain,
Peers: make(map[string]*nbpeer.Peer, len(full.Peers)),
Groups: make(map[string]*types.Group, len(full.Groups)),
Policies: make([]*types.Policy, 0, len(full.Policies)),
Routes: make([]*nbroute.Route, 0, len(full.Routes)),
NameServerGroups: make([]*nbdns.NameServerGroup, 0, len(full.NameserverGroups)),
AllDNSRecords: decodeSimpleRecords(full.AllDnsRecords),
AccountZones: decodeCustomZones(full.AccountZones),
ResourcePoliciesMap: make(map[string][]*types.Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0, len(full.NetworkResources)),
RouterPeers: make(map[string]*nbpeer.Peer),
AllowedUserIDs: stringSliceToSet(full.AllowedUserIds),
PostureFailedPeers: make(map[string]map[string]struct{}, len(full.PostureFailedPeers)),
GroupIDToUserIDs: make(map[string][]string, len(full.GroupIdToUserIds)),
}
if full.DnsSettings != nil {
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: groupIDsFromSeqs(full.DnsSettings.DisabledManagementGroupIds),
}
} else {
c.DNSSettings = &types.DNSSettings{}
}
// Phase 1: peers. The envelope's peers slice is index-addressed; we
// build a peerOrder lookup for downstream references. Peer.ID is
// synthesized from the peer's wire index — wire format ships no xid
// for peers (and never has).
peerIDByIndex := make([]string, len(full.Peers))
for idx, pc := range full.Peers {
peerID := synthPeerID(uint32(idx))
peer := decodePeerCompact(pc, peerID, full.AgentVersions)
c.Peers[peerID] = peer
peerIDByIndex[idx] = peerID
}
// Phase 2: groups. AccountSeqID becomes both the synthesized string ID
// and the GroupCompact.id wire value.
for _, gc := range full.Groups {
groupID := synthGroupID(gc.Id)
peerIDs := make([]string, 0, len(gc.PeerIndexes))
for _, idx := range gc.PeerIndexes {
if int(idx) < len(peerIDByIndex) {
peerIDs = append(peerIDs, peerIDByIndex[idx])
}
}
c.Groups[groupID] = &types.Group{
ID: groupID,
AccountSeqID: gc.Id,
Name: gc.Name,
Peers: peerIDs,
}
}
// Phase 3: policies (PolicyCompact = one rule per entry; current data
// model is 1 rule per policy). Policy.ID is synthesized from the
// per-account seq id; proto.FirewallRule.PolicyID downstream carries
// the same synth string (no xid on the wire).
for _, pc := range full.Policies {
policyID := synthPolicyID(pc.Id)
c.Policies = append(c.Policies, decodePolicyCompact(pc, policyID, peerIDByIndex))
}
// Phase 4: routes.
for _, rr := range full.Routes {
c.Routes = append(c.Routes, decodeRouteRaw(rr, peerIDByIndex))
}
// Phase 5: NSGs.
for _, nsg := range full.NameserverGroups {
c.NameServerGroups = append(c.NameServerGroups, decodeNameServerGroupRaw(nsg))
}
// Phase 6: network resources.
for _, nr := range full.NetworkResources {
c.NetworkResources = append(c.NetworkResources, decodeNetworkResource(nr))
}
// Phase 7: routers_map (outer key = network seq id, inner key = peer-id
// reconstructed from peer_index). Synthesized network id is "net_<seq>".
for networkSeq, list := range full.RoutersMap {
networkID := synthNetworkID(networkSeq)
inner := make(map[string]*routerTypes.NetworkRouter, len(list.Entries))
for _, entry := range list.Entries {
if !entry.PeerIndexSet {
continue
}
if int(entry.PeerIndex) >= len(peerIDByIndex) {
continue
}
peerID := peerIDByIndex[entry.PeerIndex]
inner[peerID] = &routerTypes.NetworkRouter{
ID: "",
NetworkID: networkID,
AccountSeqID: entry.Id,
Peer: peerID,
PeerGroups: groupIDsFromSeqs(entry.PeerGroupIds),
Masquerade: entry.Masquerade,
Metric: int(entry.Metric),
Enabled: entry.Enabled,
}
}
if len(inner) > 0 {
c.RoutersMap[networkID] = inner
}
}
// Phase 8: resource_policies_map (resource seq id → list of *types.Policy
// pointers from the decoded policies slice). Resource ID is synthesized
// the same way as in decodeNetworkResource.
for resourceSeq, idxs := range full.ResourcePoliciesMap {
if len(idxs.Indexes) == 0 {
continue
}
resourceID := synthNetworkResourceID(resourceSeq)
policies := make([]*types.Policy, 0, len(idxs.Indexes))
for _, i := range idxs.Indexes {
if int(i) < len(c.Policies) {
policies = append(policies, c.Policies[i])
}
}
if len(policies) > 0 {
c.ResourcePoliciesMap[resourceID] = policies
}
}
// Phase 9: group_id_to_user_ids — wire keys are seq ids, synth to strings.
for groupSeq, list := range full.GroupIdToUserIds {
c.GroupIDToUserIDs[synthGroupID(groupSeq)] = append([]string(nil), list.UserIds...)
}
// Phase 10: posture_failed_peers — wire keys are posture-check seq ids,
// values are peer indexes that need to be turned into peer ids. PolicyRule
// SourcePostureChecks (also synth ids) reference the same key space.
for checkSeq, set := range full.PostureFailedPeers {
checkID := synthPostureCheckID(checkSeq)
failed := make(map[string]struct{}, len(set.PeerIndexes))
for _, idx := range set.PeerIndexes {
if int(idx) < len(peerIDByIndex) {
failed[peerIDByIndex[idx]] = struct{}{}
}
}
if len(failed) > 0 {
c.PostureFailedPeers[checkID] = failed
}
}
// Phase 11: router_peer_indexes — peers that act as routers. They're
// already in c.Peers (router peers are appended to the global peers
// list by the encoder); RouterPeers is the subset.
for _, idx := range full.RouterPeerIndexes {
if int(idx) < len(peerIDByIndex) {
peerID := peerIDByIndex[idx]
c.RouterPeers[peerID] = c.Peers[peerID]
}
}
return c, nil
}
func decodeAccountNetwork(an *proto.AccountNetwork) *types.Network {
if an == nil {
return nil
}
n := &types.Network{
Identifier: an.Identifier,
Dns: an.Dns,
Serial: an.Serial,
}
if an.NetCidr != "" {
if _, ipnet, err := net.ParseCIDR(an.NetCidr); err == nil && ipnet != nil {
n.Net = *ipnet
}
}
if an.NetV6Cidr != "" {
if _, ipnet, err := net.ParseCIDR(an.NetV6Cidr); err == nil && ipnet != nil {
n.NetV6 = *ipnet
}
}
return n
}
func decodeAccountSettings(as *proto.AccountSettingsCompact) *types.AccountSettingsInfo {
if as == nil {
return &types.AccountSettingsInfo{}
}
return &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: as.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(as.PeerLoginExpirationNs),
}
}
func decodePeerCompact(pc *proto.PeerCompact, peerID string, agentVersions []string) *nbpeer.Peer {
var caps []int32
if pc.SupportsSourcePrefixes {
caps = append(caps, nbpeer.PeerCapabilitySourcePrefixes)
}
if pc.SupportsIpv6 {
caps = append(caps, nbpeer.PeerCapabilityIPv6Overlay)
}
peer := &nbpeer.Peer{
ID: peerID,
Key: encodeWgKeyBase64(pc.WgPubKey),
SSHKey: string(pc.SshPubKey),
SSHEnabled: pc.SshEnabled,
DNSLabel: pc.DnsLabel,
LoginExpirationEnabled: pc.LoginExpirationEnabled,
Meta: nbpeer.PeerSystemMeta{
WtVersion: lookupAgentVersion(agentVersions, pc.AgentVersionIdx),
Capabilities: caps,
Flags: nbpeer.Flags{
ServerSSHAllowed: pc.ServerSshAllowed,
},
},
}
if pc.AddedWithSsoLogin {
// Set a non-empty UserID so (*Peer).AddedWithSSOLogin() returns true.
// The original UserID isn't on the wire; the value is intentionally
// visibly synthetic so any future consumer that mistakes UserID for a
// real account user xid won't silently match (or worse, write the
// sentinel into a downstream record).
peer.UserID = "<env-sso>"
}
if pc.LastLoginUnixNano != 0 {
t := time.Unix(0, pc.LastLoginUnixNano)
peer.LastLogin = &t
}
switch len(pc.Ip) {
case 4:
peer.IP = netip.AddrFrom4([4]byte{pc.Ip[0], pc.Ip[1], pc.Ip[2], pc.Ip[3]})
case 16:
var a [16]byte
copy(a[:], pc.Ip)
peer.IP = netip.AddrFrom16(a)
}
if len(pc.Ipv6) == 16 {
var a [16]byte
copy(a[:], pc.Ipv6)
peer.IPv6 = netip.AddrFrom16(a)
}
return peer
}
func decodePolicyCompact(pc *proto.PolicyCompact, policyID string, peerIDByIndex []string) *types.Policy {
rule := &types.PolicyRule{
ID: policyID, // 1 rule per policy → reuse synthesized id
PolicyID: policyID,
Enabled: true,
Action: actionFromProto(pc.Action),
Protocol: protocolFromProto(pc.Protocol),
Bidirectional: pc.Bidirectional,
Ports: uint32SliceToStrings(pc.Ports),
PortRanges: portRangesFromProto(pc.PortRanges),
Sources: groupIDsFromSeqs(pc.SourceGroupIds),
Destinations: groupIDsFromSeqs(pc.DestinationGroupIds),
AuthorizedUser: pc.AuthorizedUser,
AuthorizedGroups: authorizedGroupsFromProto(pc.AuthorizedGroups),
SourceResource: resourceFromProto(pc.SourceResource, peerIDByIndex),
DestinationResource: resourceFromProto(pc.DestinationResource, peerIDByIndex),
}
return &types.Policy{
ID: policyID,
AccountSeqID: pc.Id,
Enabled: true,
Rules: []*types.PolicyRule{rule},
SourcePostureChecks: postureCheckIDsFromSeqs(pc.SourcePostureCheckSeqIds),
}
}
// resourceFromProto rebuilds types.Resource. For peer-typed resources the
// peer reference is reconstructed from the envelope's peer index — wire
// format ships no xid for peers, so we use the synthesized peer id.
func resourceFromProto(r *proto.ResourceCompact, peerIDByIndex []string) types.Resource {
if r == nil {
return types.Resource{}
}
out := types.Resource{Type: types.ResourceType(r.Type)}
if r.PeerIndexSet && int(r.PeerIndex) < len(peerIDByIndex) {
out.ID = peerIDByIndex[r.PeerIndex]
}
return out
}
// postureCheckIDsFromSeqs synths posture-check ids from per-account seq ids.
// Mirrors groupIDsFromSeqs.
func postureCheckIDsFromSeqs(seqs []uint32) []string {
if len(seqs) == 0 {
return nil
}
out := make([]string, len(seqs))
for i, s := range seqs {
out[i] = synthPostureCheckID(s)
}
return out
}
// authorizedGroupsFromProto inverts encodeAuthorizedGroups: the wire form
// keys by group account_seq_id, the typed PolicyRule field keys by group
// xid string. We rebuild using the same synthetic scheme the rest of the
// decoder uses ("g<seq>").
func authorizedGroupsFromProto(m map[uint32]*proto.UserNameList) map[string][]string {
if len(m) == 0 {
return nil
}
out := make(map[string][]string, len(m))
for seq, list := range m {
if list == nil {
continue
}
out[synthGroupID(seq)] = append([]string(nil), list.Names...)
}
return out
}
func decodeRouteRaw(rr *proto.RouteRaw, peerIDByIndex []string) *nbroute.Route {
r := &nbroute.Route{
ID: nbroute.ID(synthRouteID(rr.Id)),
AccountSeqID: rr.Id,
NetID: nbroute.NetID(rr.NetId),
Description: rr.Description,
Domains: domainsFromPunycode(rr.Domains),
KeepRoute: rr.KeepRoute,
NetworkType: nbroute.NetworkType(rr.NetworkType),
Masquerade: rr.Masquerade,
Metric: int(rr.Metric),
Enabled: rr.Enabled,
Groups: groupIDsFromSeqs(rr.GroupIds),
AccessControlGroups: groupIDsFromSeqs(rr.AccessControlGroupIds),
PeerGroups: groupIDsFromSeqs(rr.PeerGroupIds),
SkipAutoApply: rr.SkipAutoApply,
}
if rr.NetworkCidr != "" {
if p, err := netip.ParsePrefix(rr.NetworkCidr); err == nil {
r.Network = p
}
}
if rr.PeerIndexSet && int(rr.PeerIndex) < len(peerIDByIndex) {
r.Peer = peerIDByIndex[rr.PeerIndex]
}
return r
}
func decodeNameServerGroupRaw(nsg *proto.NameServerGroupRaw) *nbdns.NameServerGroup {
out := &nbdns.NameServerGroup{
ID: synthNameServerGroupID(nsg.Id),
AccountSeqID: nsg.Id,
Name: nsg.Name,
Description: nsg.Description,
Groups: groupIDsFromSeqs(nsg.GroupIds),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
NameServers: make([]nbdns.NameServer, 0, len(nsg.Nameservers)),
}
for _, ns := range nsg.Nameservers {
if addr, err := netip.ParseAddr(ns.IP); err == nil {
out.NameServers = append(out.NameServers, nbdns.NameServer{
IP: addr,
NSType: nbdns.NameServerType(ns.NSType),
Port: int(ns.Port),
})
}
}
return out
}
func decodeNetworkResource(nr *proto.NetworkResourceRaw) *resourceTypes.NetworkResource {
out := &resourceTypes.NetworkResource{
ID: synthNetworkResourceID(nr.Id),
AccountSeqID: nr.Id,
NetworkID: synthNetworkID(nr.NetworkSeq),
Name: nr.Name,
Description: nr.Description,
Type: resourceTypes.NetworkResourceType(nr.Type),
Address: nr.Address,
Domain: nr.DomainValue,
Enabled: nr.Enabled,
}
if nr.PrefixCidr != "" {
if p, err := netip.ParsePrefix(nr.PrefixCidr); err == nil {
out.Prefix = p
}
}
return out
}
func decodeSimpleRecords(records []*proto.SimpleRecord) []nbdns.SimpleRecord {
out := make([]nbdns.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, nbdns.SimpleRecord{
Name: r.Name,
Type: int(r.Type),
Class: r.Class,
TTL: int(r.TTL),
RData: r.RData,
})
}
return out
}
func decodeCustomZones(zones []*proto.CustomZone) []nbdns.CustomZone {
out := make([]nbdns.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, nbdns.CustomZone{
Domain: z.Domain,
Records: decodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
// Synthetic ID generators — deterministic given the same wire input.
// Underscore-separated ("p_<n>", "pol_<n>", ...) so they're visually
// distinct in operator logs. fmt.Sprintf would dominate the decode hot path
// on large accounts (a 10k-peer envelope produces ~50k synth calls); the
// strconv.AppendUint builder keeps it allocation-light.
func synthID(prefix string, n uint32) string {
buf := make([]byte, 0, len(prefix)+10)
buf = append(buf, prefix...)
buf = strconv.AppendUint(buf, uint64(n), 10)
return string(buf)
}
func synthPeerID(idx uint32) string { return synthID("p_", idx) }
func synthGroupID(seq uint32) string { return synthID("g_", seq) }
func synthPolicyID(seq uint32) string { return synthID("pol_", seq) }
func synthRouteID(seq uint32) string { return synthID("r_", seq) }
func synthNetworkResourceID(seq uint32) string { return synthID("nres_", seq) }
func synthPostureCheckID(seq uint32) string { return synthID("pc_", seq) }
func synthNetworkID(seq uint32) string { return synthID("net_", seq) }
func synthNameServerGroupID(seq uint32) string { return synthID("nsg_", seq) }
func groupIDsFromSeqs(seqs []uint32) []string {
if len(seqs) == 0 {
return nil
}
out := make([]string, len(seqs))
for i, s := range seqs {
out[i] = synthGroupID(s)
}
return out
}
func uint32SliceToStrings(ports []uint32) []string {
if len(ports) == 0 {
return nil
}
out := make([]string, len(ports))
for i, p := range ports {
out[i] = strconv.FormatUint(uint64(p), 10)
}
return out
}
func portRangesFromProto(ranges []*proto.PortInfo_Range) []types.RulePortRange {
if len(ranges) == 0 {
return nil
}
out := make([]types.RulePortRange, 0, len(ranges))
for _, r := range ranges {
out = append(out, types.RulePortRange{
Start: uint16(r.Start),
End: uint16(r.End),
})
}
return out
}
func actionFromProto(a proto.RuleAction) types.PolicyTrafficActionType {
if a == proto.RuleAction_DROP {
return types.PolicyTrafficActionDrop
}
return types.PolicyTrafficActionAccept
}
func protocolFromProto(p proto.RuleProtocol) types.PolicyRuleProtocolType {
switch p {
case proto.RuleProtocol_TCP:
return types.PolicyRuleProtocolTCP
case proto.RuleProtocol_UDP:
return types.PolicyRuleProtocolUDP
case proto.RuleProtocol_ICMP:
return types.PolicyRuleProtocolICMP
case proto.RuleProtocol_ALL:
return types.PolicyRuleProtocolALL
case proto.RuleProtocol_NETBIRD_SSH:
return types.PolicyRuleProtocolNetbirdSSH
default:
return types.PolicyRuleProtocolALL
}
}
func encodeWgKeyBase64(raw []byte) string {
if len(raw) != 32 {
return ""
}
return base64.StdEncoding.EncodeToString(raw)
}
func lookupAgentVersion(table []string, idx uint32) string {
if int(idx) < len(table) {
return table[idx]
}
return ""
}
func stringSliceToSet(s []string) map[string]struct{} {
if len(s) == 0 {
return nil
}
out := make(map[string]struct{}, len(s))
for _, v := range s {
out[v] = struct{}{}
}
return out
}
// domainsFromPunycode is a thin wrapper that converts a punycode list back to
// the domain.List type the route.Route struct expects. It accepts the
// punycode strings as-is (no extra decoding) — symmetric with
// route.Domains.ToPunycodeList() used in the encoder.
func domainsFromPunycode(punycoded []string) domain.List {
if len(punycoded) == 0 {
return nil
}
out := make(domain.List, 0, len(punycoded))
for _, d := range punycoded {
out = append(out, domain.Domain(d))
}
return out
}

View File

@@ -0,0 +1,323 @@
// Package networkmap contains the shared NetworkMap helpers that both the
// management server and the client agent need.
//
// The proto-conversion helpers (types.NetworkMap → proto.NetworkMap) live
// here so the client can run the same conversion locally after deriving its
// NetworkMap from a NetworkMapEnvelope, without taking a dependency on the
// server-side conversion package (which pulls in cloud integrations and is
// otherwise an unwanted internal import on the client).
//
// The helpers are pure functions over inputs — no caches, no IO, no logging
// beyond a context-aware error log when an individual user-id hash fails.
package networkmap
import (
"context"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
"net/netip"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
)
// ToProtocolRoutes converts a slice of typed routes to their proto form.
func ToProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, ToProtocolRoute(r))
}
return protoRoutes
}
// ToProtocolRoute converts one typed route to its proto form.
func ToProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// ToProtocolFirewallRules converts the firewall rules to the protocol form.
// When useSourcePrefixes is true, the compact SourcePrefixes field is
// populated alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4/v6 SourcePrefixes
// when includeIPv6 is true.
func ToProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: GetProtoDirection(rule.Direction),
Action: GetProtoAction(rule.Action),
Protocol: GetProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if ShouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is
// unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if ShouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// GetProtoDirection converts the direction to proto.RuleDirection.
func GetProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
// GetProtoAction converts the action to proto.RuleAction.
func GetProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// GetProtoProtocol converts the protocol to proto.RuleProtocol.
func GetProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
case types.PolicyRuleProtocolNetbirdSSH:
return proto.RuleProtocol_NETBIRD_SSH
default:
return proto.RuleProtocol_UNKNOWN
}
}
// GetProtoPortInfo converts route-firewall-rule port info to proto.PortInfo.
func GetProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
// ShouldUsePortRange reports whether the firewall rule should use a port
// range rather than a single port (TCP/UDP without a single port).
func ShouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// ToProtocolRoutesFirewallRules converts a slice of typed route-firewall
// rules to proto.
func ToProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: GetProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: GetProtoProtocol(rule.Protocol),
PortInfo: GetProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// ConvertToProtoCustomZone converts an nbdns.CustomZone to its proto form.
func ConvertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// ConvertToProtoNameServerGroup converts a NameServerGroup to its proto form.
func ConvertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// DNSConfigCache is the cache contract for amortising NameServerGroup
// proto-conversion across peers in the same account. Server uses a concrete
// implementation; client passes nil (no cross-peer caching needed when
// rebuilding a single NetworkMap from an envelope).
type DNSConfigCache interface {
GetNameServerGroup(key string) (*proto.NameServerGroup, bool)
SetNameServerGroup(key string, value *proto.NameServerGroup)
}
// ToProtocolDNSConfig converts nbdns.Config to proto.DNSConfig. If cache is
// non-nil, NameServerGroup proto values are cached by NSG.ID across calls —
// the server amortises this across peers, the client passes nil.
func ToProtocolDNSConfig(update nbdns.Config, cache DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, ConvertToProtoCustomZone(zone))
}
for _, nsGroup := range update.NameServerGroups {
if cache != nil {
if cachedGroup, exists := cache.GetNameServerGroup(nsGroup.ID); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
continue
}
}
protoGroup := ConvertToProtoNameServerGroup(nsGroup)
if cache != nil {
cache.SetNameServerGroup(nsGroup.ID, protoGroup)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
// AppendRemotePeerConfig appends typed peers as proto.RemotePeerConfig
// entries to dst and returns the result.
func AppendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// BuildAuthorizedUsersProto deduplicates user-IDs into a hashed list and
// builds per-machine-user index maps. Returns (hashedUsers, machineUsers).
// Errors from individual hash failures are logged via the provided context;
// they leave the offending user out of the result but don't abort the build.
func BuildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}

View File

@@ -0,0 +1,204 @@
package networkmap
import (
"bytes"
"context"
"encoding/base64"
"fmt"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// EnvelopeResult is what the client engine consumes after receiving a
// component-format NetworkMap. Both fields are populated:
//
// - NetworkMap is the *proto.NetworkMap shape the engine reads today via
// update.GetNetworkMap() — built from the envelope's components by
// running Calculate() locally + converting back through the shared
// proto helpers + merging the optional ProxyPatch.
// - Components is the *types.NetworkMapComponents the engine retains so
// future incremental delta updates (Step 3) have a base to apply
// changes against. The client keeps it under its sync lock.
type EnvelopeResult struct {
NetworkMap *proto.NetworkMap
Components *types.NetworkMapComponents
}
// EnvelopeToNetworkMap is the full client-side pipeline: decode the
// component envelope back to a typed NetworkMapComponents, run Calculate()
// locally to produce the typed NetworkMap, convert it to the wire form the
// engine consumes, and fold in any ProxyPatch the server attached.
//
// localPeerKey is the receiving peer's WG pub key (used to derive
// includeIPv6 / useSourcePrefixes from the receiving peer's own record in
// the components struct, mirroring legacy ToSyncResponse behaviour).
//
// dnsName is the account's DNS domain ("netbird.cloud" etc.); used when
// rebuilding the per-peer FQDNs that proto.RemotePeerConfig carries.
func EnvelopeToNetworkMap(ctx context.Context, env *proto.NetworkMapEnvelope, localPeerKey, dnsName string) (*EnvelopeResult, error) {
components, err := DecodeEnvelope(env)
if err != nil {
return nil, fmt.Errorf("decode envelope: %w", err)
}
// Find the receiving peer in the decoded components by WG key so we can
// derive its capabilities and set components.PeerID for Calculate(). The
// envelope.peers list is index-addressed; we synthesized IDs as "p<idx>".
localPeerID, localPeer := findPeerByWgKey(components, localPeerKey)
if localPeer == nil {
return nil, fmt.Errorf("receiving peer (wg_key prefix %q) not found among %d decoded peers — components have no PeerID, Calculate would return empty", trimKey(localPeerKey), len(components.Peers))
}
components.PeerID = localPeerID
includeIPv6 := localPeer != nil && localPeer.SupportsIPv6() && localPeer.IPv6.IsValid()
useSourcePrefixes := localPeer != nil && localPeer.SupportsSourcePrefixes()
typedNM := components.Calculate(ctx)
full := env.GetFull()
dnsFwdPort := int64(0)
if full != nil {
dnsFwdPort = full.DnsForwarderPort
}
protoNM := &proto.NetworkMap{
Serial: typedNM.Network.CurrentSerial(),
}
if full != nil {
protoNM.PeerConfig = full.PeerConfig
}
protoNM.Routes = ToProtocolRoutes(typedNM.Routes)
protoNM.DNSConfig = ToProtocolDNSConfig(typedNM.DNSConfig, nil, dnsFwdPort)
remotePeers := AppendRemotePeerConfig(nil, typedNM.Peers, dnsName, includeIPv6)
protoNM.RemotePeers = remotePeers
protoNM.RemotePeersIsEmpty = len(remotePeers) == 0
protoNM.OfflinePeers = AppendRemotePeerConfig(nil, typedNM.OfflinePeers, dnsName, includeIPv6)
firewallRules := ToProtocolFirewallRules(typedNM.FirewallRules, includeIPv6, useSourcePrefixes)
protoNM.FirewallRules = firewallRules
protoNM.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := ToProtocolRoutesFirewallRules(typedNM.RoutesFirewallRules)
protoNM.RoutesFirewallRules = routesFirewallRules
protoNM.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
if typedNM.AuthorizedUsers != nil {
hashedUsers, machineUsers := BuildAuthorizedUsersProto(ctx, typedNM.AuthorizedUsers)
userIDClaim := ""
if full != nil {
userIDClaim = full.UserIdClaim
}
protoNM.SshAuth = &proto.SSHAuth{
AuthorizedUsers: hashedUsers,
MachineUsers: machineUsers,
UserIDClaim: userIDClaim,
}
}
if typedNM.ForwardingRules != nil {
forwardingRules := make([]*proto.ForwardingRule, 0, len(typedNM.ForwardingRules))
for _, rule := range typedNM.ForwardingRules {
forwardingRules = append(forwardingRules, rule.ToProto())
}
protoNM.ForwardingRules = forwardingRules
}
// Merge the proxy patch the server attached. Mirrors the legacy
// NetworkMap.Merge step that the server runs after Calculate().
if full != nil && full.ProxyPatch != nil {
mergeProxyPatch(protoNM, full.ProxyPatch)
}
return &EnvelopeResult{
NetworkMap: protoNM,
Components: components,
}, nil
}
// mergeProxyPatch folds a ProxyPatch's pre-expanded fragments into the
// proto.NetworkMap that Calculate() produced. Mirrors types.NetworkMap.Merge
// — same six collections, deduplicated where the legacy merge dedupes.
func mergeProxyPatch(nm *proto.NetworkMap, patch *proto.ProxyPatch) {
nm.RemotePeers = appendUniquePeers(nm.RemotePeers, patch.Peers)
nm.OfflinePeers = appendUniquePeers(nm.OfflinePeers, patch.OfflinePeers)
nm.FirewallRules = append(nm.FirewallRules, patch.FirewallRules...)
nm.Routes = append(nm.Routes, patch.Routes...)
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, patch.RouteFirewallRules...)
nm.ForwardingRules = append(nm.ForwardingRules, patch.ForwardingRules...)
if len(nm.RemotePeers) > 0 {
nm.RemotePeersIsEmpty = false
}
if len(nm.FirewallRules) > 0 {
nm.FirewallRulesIsEmpty = false
}
if len(nm.RoutesFirewallRules) > 0 {
nm.RoutesFirewallRulesIsEmpty = false
}
}
// appendUniquePeers dedupes by WgPubKey — mirrors legacy
// mergeUniquePeersByID's intent (legacy keyed off Peer.ID; in proto form the
// closest stable identifier is WgPubKey).
func appendUniquePeers(dst, extra []*proto.RemotePeerConfig) []*proto.RemotePeerConfig {
if len(extra) == 0 {
return dst
}
seen := make(map[string]struct{}, len(dst))
for _, p := range dst {
seen[p.WgPubKey] = struct{}{}
}
for _, p := range extra {
if _, ok := seen[p.WgPubKey]; ok {
continue
}
seen[p.WgPubKey] = struct{}{}
dst = append(dst, p)
}
return dst
}
func trimKey(s string) string {
if len(s) > 12 {
return s[:12]
}
return s
}
// findPeerByWgKey locates the receiving peer in the decoded components by
// matching its WireGuard public key. Compares raw 32-byte decode output —
// not the base64 string — because production data has occasional non-canonical
// padding bits that round-trip through the envelope's `bytes wg_pub_key`
// field, canonicalising the encoding (semantically equivalent key, different
// string). Decodes `wgKey` once up front and reuses a stack buffer in the
// loop so an N-peer search is ~zero-alloc.
func findPeerByWgKey(c *types.NetworkMapComponents, wgKey string) (string, *nbpeer.Peer) {
const wgKeyRawLen = 32
var (
targetRaw [wgKeyRawLen]byte
haveRaw bool
)
if n, err := base64.StdEncoding.Decode(targetRaw[:], []byte(wgKey)); err == nil && n == wgKeyRawLen {
haveRaw = true
}
var peerRaw [wgKeyRawLen]byte
for id, p := range c.Peers {
if p == nil {
continue
}
if p.Key == wgKey {
return id, p
}
if !haveRaw {
continue
}
n, err := base64.StdEncoding.Decode(peerRaw[:], []byte(p.Key))
if err == nil && n == wgKeyRawLen && bytes.Equal(peerRaw[:], targetRaw[:]) {
return id, p
}
}
return "", nil
}

View File

@@ -0,0 +1,173 @@
package networkmap_test
import (
"context"
"crypto/rand"
"encoding/base64"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// TestEnvelopeToNetworkMap_RoundTrip exercises the full client-side pipeline:
// build a small components struct, encode an envelope, marshal/unmarshal the
// wire bytes, decode back via EnvelopeToNetworkMap, and verify the result is
// non-empty and consistent. Deeper per-field semantic equivalence with the
// legacy server path is covered by the prod-DB equivalence test in
// management/server/store/networkmap_envelope_equivalence_test.go.
func TestEnvelopeToNetworkMap_RoundTrip(t *testing.T) {
c, localPeerKey := buildSmokeComponents(t)
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
wire, err := goproto.Marshal(envelope)
require.NoError(t, err, "marshal envelope")
var decoded proto.NetworkMapEnvelope
require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope")
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
require.NoError(t, err, "EnvelopeToNetworkMap")
require.NotNil(t, result)
require.NotNil(t, result.NetworkMap, "decoded NetworkMap must be non-nil")
require.NotNil(t, result.Components, "Components must be retained for future delta updates")
require.NotNil(t, result.Components.AccountSettings)
require.NotEmpty(t, result.NetworkMap.RemotePeers, "two-peer allow policy should produce one remote peer")
require.NotEmpty(t, result.NetworkMap.FirewallRules, "two-peer allow policy should produce firewall rules")
}
// TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH guards against the
// scenario where a rule with Protocol=NetbirdSSH leaks the enum value into
// proto.FirewallRule.Protocol. Calculate() must rewrite NetbirdSSH → TCP
// before forming firewall rules (see networkmap_components.go:282 and
// account.go:868). Without that rewrite, agents fall into UNKNOWN-protocol
// handling, which on some platforms downgrades to allow-all — a real
// security regression.
func TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH(t *testing.T) {
c, localPeerKey := buildSmokeComponents(t)
// Replace the smoke policy with a NetbirdSSH-protocol allow.
c.Policies = []*types.Policy{{
ID: "pol-ssh", AccountSeqID: 2, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-ssh",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
Bidirectional: true,
Sources: []string{"group-all"},
Destinations: []string{"group-all"},
}},
}}
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
wire, err := goproto.Marshal(envelope)
require.NoError(t, err)
var decoded proto.NetworkMapEnvelope
require.NoError(t, goproto.Unmarshal(wire, &decoded))
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
require.NoError(t, err)
require.NotEmpty(t, result.NetworkMap.FirewallRules, "ssh policy should produce firewall rules")
for i, fr := range result.NetworkMap.FirewallRules {
require.NotEqualf(t, proto.RuleProtocol_NETBIRD_SSH, fr.Protocol,
"FirewallRules[%d].Protocol must be the rewritten TCP, not NETBIRD_SSH", i)
}
}
func TestEnvelopeToNetworkMap_NilEnvelope(t *testing.T) {
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), nil, "key", "netbird.cloud")
require.Error(t, err, "nil envelope must produce an error rather than panic")
}
func TestEnvelopeToNetworkMap_FullPayloadMissing(t *testing.T) {
env := &proto.NetworkMapEnvelope{}
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), env, "key", "netbird.cloud")
require.Error(t, err, "envelope with no Full payload must produce an error")
}
// buildSmokeComponents returns a minimal NetworkMapComponents (2 peers, 1
// group, 1 allow policy) plus the receiving peer's WG public key. Sufficient
// to validate the encode → marshal → decode → Calculate pipeline produces
// non-empty output.
func buildSmokeComponents(t *testing.T) (*types.NetworkMapComponents, string) {
t.Helper()
peerAKey := randomWgKey(t)
peerBKey := randomWgKey(t)
peerA := &nbpeer.Peer{
ID: "peer-A",
Key: peerAKey,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peerA",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-B",
Key: peerBKey,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
DNSLabel: "peerB",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
group := &types.Group{
ID: "group-all", AccountSeqID: 1, Name: "All",
Peers: []string{"peer-A", "peer-B"},
}
policy := &types.Policy{
ID: "pol-allow", AccountSeqID: 1, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-allow",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-all"},
Destinations: []string{"group-all"},
}},
}
c := &types.NetworkMapComponents{
PeerID: "peer-A",
Network: &types.Network{
Identifier: "net-smoke",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 1,
},
AccountSettings: &types.AccountSettingsInfo{},
DNSSettings: &types.DNSSettings{},
Peers: map[string]*nbpeer.Peer{
"peer-A": peerA,
"peer-B": peerB,
},
Groups: map[string]*types.Group{
"group-all": group,
},
Policies: []*types.Policy{policy},
}
return c, peerAKey
}
func randomWgKey(t *testing.T) string {
t.Helper()
var raw [32]byte
_, err := rand.Read(raw[:])
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(raw[:])
}