From 728057ef15f5a34e22ea03028846c10529213aac Mon Sep 17 00:00:00 2001 From: crn4 Date: Tue, 12 May 2026 16:54:56 +0200 Subject: [PATCH] missed files for client side and shared files --- .../grpc/components_envelope_response_test.go | 186 ++++++ shared/management/networkmap/decode.go | 586 ++++++++++++++++++ shared/management/networkmap/encode.go | 323 ++++++++++ shared/management/networkmap/envelope.go | 204 ++++++ shared/management/networkmap/envelope_test.go | 173 ++++++ 5 files changed, 1472 insertions(+) create mode 100644 management/internals/shared/grpc/components_envelope_response_test.go create mode 100644 shared/management/networkmap/decode.go create mode 100644 shared/management/networkmap/encode.go create mode 100644 shared/management/networkmap/envelope.go create mode 100644 shared/management/networkmap/envelope_test.go diff --git a/management/internals/shared/grpc/components_envelope_response_test.go b/management/internals/shared/grpc/components_envelope_response_test.go new file mode 100644 index 000000000..dfb6b5734 --- /dev/null +++ b/management/internals/shared/grpc/components_envelope_response_test.go @@ -0,0 +1,186 @@ +package grpc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" +) + +// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches: +// explicit NetbirdSSH protocol, and the legacy implicit case where a +// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when +// the destination peer has SSHEnabled=true locally. Belt-and-suspenders for +// the B1 fix that the prod-DB equivalence test alone wouldn't have caught +// if no account had this combination. +func TestComputeSSHEnabledForPeer(t *testing.T) { + const targetPeerID = "target" + const targetGroupID = "g_dst" + + mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) { + peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled} + group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}} + return &types.NetworkMapComponents{ + Peers: map[string]*nbpeer.Peer{targetPeerID: peer}, + Groups: map[string]*types.Group{targetGroupID: group}, + Policies: []*types.Policy{{ + ID: "p", + Enabled: true, + Rules: []*types.PolicyRule{rule}, + }}, + }, peer + } + + cases := []struct { + name string + peerSSH bool + rule types.PolicyRule + wantEnabled bool + }{ + { + name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh", + peerSSH: false, + rule: types.PolicyRule{ + Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Destinations: []string{targetGroupID}, + }, + wantEnabled: true, + }, + { + name: "implicit-tcp-22-with-peer-ssh", + peerSSH: true, + rule: types.PolicyRule{ + Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"}, + Destinations: []string{targetGroupID}, + }, + wantEnabled: true, + }, + { + name: "implicit-tcp-22-without-peer-ssh-disabled", + peerSSH: false, + rule: types.PolicyRule{ + Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"}, + Destinations: []string{targetGroupID}, + }, + wantEnabled: false, + }, + { + name: "implicit-tcp-22022-with-peer-ssh", + peerSSH: true, + rule: types.PolicyRule{ + Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"}, + Destinations: []string{targetGroupID}, + }, + wantEnabled: true, + }, + { + name: "implicit-all-protocol-with-peer-ssh", + peerSSH: true, + rule: types.PolicyRule{ + Enabled: true, Protocol: types.PolicyRuleProtocolALL, + Destinations: []string{targetGroupID}, + }, + wantEnabled: true, + }, + { + name: "implicit-port-range-covers-22", + peerSSH: true, + rule: types.PolicyRule{ + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + PortRanges: []types.RulePortRange{{Start: 20, End: 30}}, + Destinations: []string{targetGroupID}, + }, + wantEnabled: true, + }, + { + name: "tcp-80-no-ssh", + peerSSH: true, + rule: types.PolicyRule{ + Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"}, + Destinations: []string{targetGroupID}, + }, + wantEnabled: false, + }, + { + name: "disabled-rule-skipped", + peerSSH: true, + rule: types.PolicyRule{ + Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Destinations: []string{targetGroupID}, + }, + wantEnabled: false, + }, + { + name: "peer-not-in-destinations", + peerSSH: true, + rule: types.PolicyRule{ + Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Destinations: []string{"g_other"}, // target not in this group + }, + wantEnabled: false, + }, + { + name: "peer-typed-destination-resource-matches", + peerSSH: false, + rule: types.PolicyRule{ + Enabled: true, + Protocol: types.PolicyRuleProtocolNetbirdSSH, + DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer}, + }, + wantEnabled: true, + }, + { + name: "non-peer-destination-resource-falls-through-to-groups", + peerSSH: false, + rule: types.PolicyRule{ + Enabled: true, + Protocol: types.PolicyRuleProtocolNetbirdSSH, + DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type + Destinations: []string{targetGroupID}, // saved by group fallback + }, + wantEnabled: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c, peer := mkComponents(&tc.rule, tc.peerSSH) + got := computeSSHEnabledForPeer(c, peer) + assert.Equal(t, tc.wantEnabled, got) + }) + } +} + +// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the +// belt-and-suspenders presence guard mirroring Calculate's +// getAllPeersFromGroups invariant. +func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) { + peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true} + c := &types.NetworkMapComponents{ + Peers: map[string]*nbpeer.Peer{}, // target peer NOT present + Groups: map[string]*types.Group{ + "g": {ID: "g", Peers: []string{"missing"}}, + }, + Policies: []*types.Policy{{ + ID: "p", Enabled: true, + Rules: []*types.PolicyRule{{ + Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Destinations: []string{"g"}, + }}, + }}, + } + assert.False(t, computeSSHEnabledForPeer(c, peer), + "missing target peer must short-circuit to false, not consult policies") +} + +// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at +// function entry — Calculate doesn't accept nil either, but the helper is +// exported indirectly via ToComponentSyncResponse and may receive nil +// components on graceful-degrade paths. +func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) { + assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"})) + assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil)) +} diff --git a/shared/management/networkmap/decode.go b/shared/management/networkmap/decode.go new file mode 100644 index 000000000..84b104986 --- /dev/null +++ b/shared/management/networkmap/decode.go @@ -0,0 +1,586 @@ +package networkmap + +import ( + "encoding/base64" + "fmt" + "net" + "net/netip" + "strconv" + "time" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + nbroute "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// DecodeEnvelope converts a NetworkMapEnvelope into a NetworkMapComponents +// the client can run Calculate() over. Every ID-reference on the wire is a +// uint32 (peer index or account_seq_id) — no xid strings travel. The decoder +// synthesises consistent string IDs from the uint32s so the reconstructed +// components struct round-trips through Calculate exactly the way the +// server-side typed components would. +// +// Synthetic ID scheme (underscore-separated, visually distinct from the xid +// format Calculate would put in log lines under the legacy path): +// +// Peers "p_" // envelope.peers is index-addressed +// Groups "g_" +// Policies "pol_" // 1 rule per policy +// Routes "r_" +// Network resources "nres_" +// Posture checks "pc_" +// Networks "net_" +// Nameserver groups "nsg_" +func DecodeEnvelope(env *proto.NetworkMapEnvelope) (*types.NetworkMapComponents, error) { + if env == nil { + return nil, fmt.Errorf("nil envelope") + } + full := env.GetFull() + if full == nil { + return nil, fmt.Errorf("envelope has no Full payload") + } + + c := &types.NetworkMapComponents{ + PeerID: "", // engine fills its own peer id from PeerConfig + Network: decodeAccountNetwork(full.Network), + AccountSettings: decodeAccountSettings(full.AccountSettings), + CustomZoneDomain: full.CustomZoneDomain, + Peers: make(map[string]*nbpeer.Peer, len(full.Peers)), + Groups: make(map[string]*types.Group, len(full.Groups)), + Policies: make([]*types.Policy, 0, len(full.Policies)), + Routes: make([]*nbroute.Route, 0, len(full.Routes)), + NameServerGroups: make([]*nbdns.NameServerGroup, 0, len(full.NameserverGroups)), + AllDNSRecords: decodeSimpleRecords(full.AllDnsRecords), + AccountZones: decodeCustomZones(full.AccountZones), + ResourcePoliciesMap: make(map[string][]*types.Policy), + RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), + NetworkResources: make([]*resourceTypes.NetworkResource, 0, len(full.NetworkResources)), + RouterPeers: make(map[string]*nbpeer.Peer), + AllowedUserIDs: stringSliceToSet(full.AllowedUserIds), + PostureFailedPeers: make(map[string]map[string]struct{}, len(full.PostureFailedPeers)), + GroupIDToUserIDs: make(map[string][]string, len(full.GroupIdToUserIds)), + } + + if full.DnsSettings != nil { + c.DNSSettings = &types.DNSSettings{ + DisabledManagementGroups: groupIDsFromSeqs(full.DnsSettings.DisabledManagementGroupIds), + } + } else { + c.DNSSettings = &types.DNSSettings{} + } + + // Phase 1: peers. The envelope's peers slice is index-addressed; we + // build a peerOrder lookup for downstream references. Peer.ID is + // synthesized from the peer's wire index — wire format ships no xid + // for peers (and never has). + peerIDByIndex := make([]string, len(full.Peers)) + for idx, pc := range full.Peers { + peerID := synthPeerID(uint32(idx)) + peer := decodePeerCompact(pc, peerID, full.AgentVersions) + c.Peers[peerID] = peer + peerIDByIndex[idx] = peerID + } + + // Phase 2: groups. AccountSeqID becomes both the synthesized string ID + // and the GroupCompact.id wire value. + for _, gc := range full.Groups { + groupID := synthGroupID(gc.Id) + peerIDs := make([]string, 0, len(gc.PeerIndexes)) + for _, idx := range gc.PeerIndexes { + if int(idx) < len(peerIDByIndex) { + peerIDs = append(peerIDs, peerIDByIndex[idx]) + } + } + c.Groups[groupID] = &types.Group{ + ID: groupID, + AccountSeqID: gc.Id, + Name: gc.Name, + Peers: peerIDs, + } + } + + // Phase 3: policies (PolicyCompact = one rule per entry; current data + // model is 1 rule per policy). Policy.ID is synthesized from the + // per-account seq id; proto.FirewallRule.PolicyID downstream carries + // the same synth string (no xid on the wire). + for _, pc := range full.Policies { + policyID := synthPolicyID(pc.Id) + c.Policies = append(c.Policies, decodePolicyCompact(pc, policyID, peerIDByIndex)) + } + + // Phase 4: routes. + for _, rr := range full.Routes { + c.Routes = append(c.Routes, decodeRouteRaw(rr, peerIDByIndex)) + } + + // Phase 5: NSGs. + for _, nsg := range full.NameserverGroups { + c.NameServerGroups = append(c.NameServerGroups, decodeNameServerGroupRaw(nsg)) + } + + // Phase 6: network resources. + for _, nr := range full.NetworkResources { + c.NetworkResources = append(c.NetworkResources, decodeNetworkResource(nr)) + } + + // Phase 7: routers_map (outer key = network seq id, inner key = peer-id + // reconstructed from peer_index). Synthesized network id is "net_". + for networkSeq, list := range full.RoutersMap { + networkID := synthNetworkID(networkSeq) + inner := make(map[string]*routerTypes.NetworkRouter, len(list.Entries)) + for _, entry := range list.Entries { + if !entry.PeerIndexSet { + continue + } + if int(entry.PeerIndex) >= len(peerIDByIndex) { + continue + } + peerID := peerIDByIndex[entry.PeerIndex] + inner[peerID] = &routerTypes.NetworkRouter{ + ID: "", + NetworkID: networkID, + AccountSeqID: entry.Id, + Peer: peerID, + PeerGroups: groupIDsFromSeqs(entry.PeerGroupIds), + Masquerade: entry.Masquerade, + Metric: int(entry.Metric), + Enabled: entry.Enabled, + } + } + if len(inner) > 0 { + c.RoutersMap[networkID] = inner + } + } + + // Phase 8: resource_policies_map (resource seq id → list of *types.Policy + // pointers from the decoded policies slice). Resource ID is synthesized + // the same way as in decodeNetworkResource. + for resourceSeq, idxs := range full.ResourcePoliciesMap { + if len(idxs.Indexes) == 0 { + continue + } + resourceID := synthNetworkResourceID(resourceSeq) + policies := make([]*types.Policy, 0, len(idxs.Indexes)) + for _, i := range idxs.Indexes { + if int(i) < len(c.Policies) { + policies = append(policies, c.Policies[i]) + } + } + if len(policies) > 0 { + c.ResourcePoliciesMap[resourceID] = policies + } + } + + // Phase 9: group_id_to_user_ids — wire keys are seq ids, synth to strings. + for groupSeq, list := range full.GroupIdToUserIds { + c.GroupIDToUserIDs[synthGroupID(groupSeq)] = append([]string(nil), list.UserIds...) + } + + // Phase 10: posture_failed_peers — wire keys are posture-check seq ids, + // values are peer indexes that need to be turned into peer ids. PolicyRule + // SourcePostureChecks (also synth ids) reference the same key space. + for checkSeq, set := range full.PostureFailedPeers { + checkID := synthPostureCheckID(checkSeq) + failed := make(map[string]struct{}, len(set.PeerIndexes)) + for _, idx := range set.PeerIndexes { + if int(idx) < len(peerIDByIndex) { + failed[peerIDByIndex[idx]] = struct{}{} + } + } + if len(failed) > 0 { + c.PostureFailedPeers[checkID] = failed + } + } + + // Phase 11: router_peer_indexes — peers that act as routers. They're + // already in c.Peers (router peers are appended to the global peers + // list by the encoder); RouterPeers is the subset. + for _, idx := range full.RouterPeerIndexes { + if int(idx) < len(peerIDByIndex) { + peerID := peerIDByIndex[idx] + c.RouterPeers[peerID] = c.Peers[peerID] + } + } + + return c, nil +} + +func decodeAccountNetwork(an *proto.AccountNetwork) *types.Network { + if an == nil { + return nil + } + n := &types.Network{ + Identifier: an.Identifier, + Dns: an.Dns, + Serial: an.Serial, + } + if an.NetCidr != "" { + if _, ipnet, err := net.ParseCIDR(an.NetCidr); err == nil && ipnet != nil { + n.Net = *ipnet + } + } + if an.NetV6Cidr != "" { + if _, ipnet, err := net.ParseCIDR(an.NetV6Cidr); err == nil && ipnet != nil { + n.NetV6 = *ipnet + } + } + return n +} + +func decodeAccountSettings(as *proto.AccountSettingsCompact) *types.AccountSettingsInfo { + if as == nil { + return &types.AccountSettingsInfo{} + } + return &types.AccountSettingsInfo{ + PeerLoginExpirationEnabled: as.PeerLoginExpirationEnabled, + PeerLoginExpiration: time.Duration(as.PeerLoginExpirationNs), + } +} + +func decodePeerCompact(pc *proto.PeerCompact, peerID string, agentVersions []string) *nbpeer.Peer { + var caps []int32 + if pc.SupportsSourcePrefixes { + caps = append(caps, nbpeer.PeerCapabilitySourcePrefixes) + } + if pc.SupportsIpv6 { + caps = append(caps, nbpeer.PeerCapabilityIPv6Overlay) + } + peer := &nbpeer.Peer{ + ID: peerID, + Key: encodeWgKeyBase64(pc.WgPubKey), + SSHKey: string(pc.SshPubKey), + SSHEnabled: pc.SshEnabled, + DNSLabel: pc.DnsLabel, + LoginExpirationEnabled: pc.LoginExpirationEnabled, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: lookupAgentVersion(agentVersions, pc.AgentVersionIdx), + Capabilities: caps, + Flags: nbpeer.Flags{ + ServerSSHAllowed: pc.ServerSshAllowed, + }, + }, + } + if pc.AddedWithSsoLogin { + // Set a non-empty UserID so (*Peer).AddedWithSSOLogin() returns true. + // The original UserID isn't on the wire; the value is intentionally + // visibly synthetic so any future consumer that mistakes UserID for a + // real account user xid won't silently match (or worse, write the + // sentinel into a downstream record). + peer.UserID = "" + } + if pc.LastLoginUnixNano != 0 { + t := time.Unix(0, pc.LastLoginUnixNano) + peer.LastLogin = &t + } + switch len(pc.Ip) { + case 4: + peer.IP = netip.AddrFrom4([4]byte{pc.Ip[0], pc.Ip[1], pc.Ip[2], pc.Ip[3]}) + case 16: + var a [16]byte + copy(a[:], pc.Ip) + peer.IP = netip.AddrFrom16(a) + } + if len(pc.Ipv6) == 16 { + var a [16]byte + copy(a[:], pc.Ipv6) + peer.IPv6 = netip.AddrFrom16(a) + } + return peer +} + +func decodePolicyCompact(pc *proto.PolicyCompact, policyID string, peerIDByIndex []string) *types.Policy { + rule := &types.PolicyRule{ + ID: policyID, // 1 rule per policy → reuse synthesized id + PolicyID: policyID, + Enabled: true, + Action: actionFromProto(pc.Action), + Protocol: protocolFromProto(pc.Protocol), + Bidirectional: pc.Bidirectional, + Ports: uint32SliceToStrings(pc.Ports), + PortRanges: portRangesFromProto(pc.PortRanges), + Sources: groupIDsFromSeqs(pc.SourceGroupIds), + Destinations: groupIDsFromSeqs(pc.DestinationGroupIds), + AuthorizedUser: pc.AuthorizedUser, + AuthorizedGroups: authorizedGroupsFromProto(pc.AuthorizedGroups), + SourceResource: resourceFromProto(pc.SourceResource, peerIDByIndex), + DestinationResource: resourceFromProto(pc.DestinationResource, peerIDByIndex), + } + return &types.Policy{ + ID: policyID, + AccountSeqID: pc.Id, + Enabled: true, + Rules: []*types.PolicyRule{rule}, + SourcePostureChecks: postureCheckIDsFromSeqs(pc.SourcePostureCheckSeqIds), + } +} + +// resourceFromProto rebuilds types.Resource. For peer-typed resources the +// peer reference is reconstructed from the envelope's peer index — wire +// format ships no xid for peers, so we use the synthesized peer id. +func resourceFromProto(r *proto.ResourceCompact, peerIDByIndex []string) types.Resource { + if r == nil { + return types.Resource{} + } + out := types.Resource{Type: types.ResourceType(r.Type)} + if r.PeerIndexSet && int(r.PeerIndex) < len(peerIDByIndex) { + out.ID = peerIDByIndex[r.PeerIndex] + } + return out +} + +// postureCheckIDsFromSeqs synths posture-check ids from per-account seq ids. +// Mirrors groupIDsFromSeqs. +func postureCheckIDsFromSeqs(seqs []uint32) []string { + if len(seqs) == 0 { + return nil + } + out := make([]string, len(seqs)) + for i, s := range seqs { + out[i] = synthPostureCheckID(s) + } + return out +} + +// authorizedGroupsFromProto inverts encodeAuthorizedGroups: the wire form +// keys by group account_seq_id, the typed PolicyRule field keys by group +// xid string. We rebuild using the same synthetic scheme the rest of the +// decoder uses ("g"). +func authorizedGroupsFromProto(m map[uint32]*proto.UserNameList) map[string][]string { + if len(m) == 0 { + return nil + } + out := make(map[string][]string, len(m)) + for seq, list := range m { + if list == nil { + continue + } + out[synthGroupID(seq)] = append([]string(nil), list.Names...) + } + return out +} + +func decodeRouteRaw(rr *proto.RouteRaw, peerIDByIndex []string) *nbroute.Route { + r := &nbroute.Route{ + ID: nbroute.ID(synthRouteID(rr.Id)), + AccountSeqID: rr.Id, + NetID: nbroute.NetID(rr.NetId), + Description: rr.Description, + Domains: domainsFromPunycode(rr.Domains), + KeepRoute: rr.KeepRoute, + NetworkType: nbroute.NetworkType(rr.NetworkType), + Masquerade: rr.Masquerade, + Metric: int(rr.Metric), + Enabled: rr.Enabled, + Groups: groupIDsFromSeqs(rr.GroupIds), + AccessControlGroups: groupIDsFromSeqs(rr.AccessControlGroupIds), + PeerGroups: groupIDsFromSeqs(rr.PeerGroupIds), + SkipAutoApply: rr.SkipAutoApply, + } + if rr.NetworkCidr != "" { + if p, err := netip.ParsePrefix(rr.NetworkCidr); err == nil { + r.Network = p + } + } + if rr.PeerIndexSet && int(rr.PeerIndex) < len(peerIDByIndex) { + r.Peer = peerIDByIndex[rr.PeerIndex] + } + return r +} + +func decodeNameServerGroupRaw(nsg *proto.NameServerGroupRaw) *nbdns.NameServerGroup { + out := &nbdns.NameServerGroup{ + ID: synthNameServerGroupID(nsg.Id), + AccountSeqID: nsg.Id, + Name: nsg.Name, + Description: nsg.Description, + Groups: groupIDsFromSeqs(nsg.GroupIds), + Primary: nsg.Primary, + Domains: nsg.Domains, + Enabled: nsg.Enabled, + SearchDomainsEnabled: nsg.SearchDomainsEnabled, + NameServers: make([]nbdns.NameServer, 0, len(nsg.Nameservers)), + } + for _, ns := range nsg.Nameservers { + if addr, err := netip.ParseAddr(ns.IP); err == nil { + out.NameServers = append(out.NameServers, nbdns.NameServer{ + IP: addr, + NSType: nbdns.NameServerType(ns.NSType), + Port: int(ns.Port), + }) + } + } + return out +} + +func decodeNetworkResource(nr *proto.NetworkResourceRaw) *resourceTypes.NetworkResource { + out := &resourceTypes.NetworkResource{ + ID: synthNetworkResourceID(nr.Id), + AccountSeqID: nr.Id, + NetworkID: synthNetworkID(nr.NetworkSeq), + Name: nr.Name, + Description: nr.Description, + Type: resourceTypes.NetworkResourceType(nr.Type), + Address: nr.Address, + Domain: nr.DomainValue, + Enabled: nr.Enabled, + } + if nr.PrefixCidr != "" { + if p, err := netip.ParsePrefix(nr.PrefixCidr); err == nil { + out.Prefix = p + } + } + return out +} + +func decodeSimpleRecords(records []*proto.SimpleRecord) []nbdns.SimpleRecord { + out := make([]nbdns.SimpleRecord, 0, len(records)) + for _, r := range records { + out = append(out, nbdns.SimpleRecord{ + Name: r.Name, + Type: int(r.Type), + Class: r.Class, + TTL: int(r.TTL), + RData: r.RData, + }) + } + return out +} + +func decodeCustomZones(zones []*proto.CustomZone) []nbdns.CustomZone { + out := make([]nbdns.CustomZone, 0, len(zones)) + for _, z := range zones { + out = append(out, nbdns.CustomZone{ + Domain: z.Domain, + Records: decodeSimpleRecords(z.Records), + SearchDomainDisabled: z.SearchDomainDisabled, + NonAuthoritative: z.NonAuthoritative, + }) + } + return out +} + +// Synthetic ID generators — deterministic given the same wire input. +// Underscore-separated ("p_", "pol_", ...) so they're visually +// distinct in operator logs. fmt.Sprintf would dominate the decode hot path +// on large accounts (a 10k-peer envelope produces ~50k synth calls); the +// strconv.AppendUint builder keeps it allocation-light. +func synthID(prefix string, n uint32) string { + buf := make([]byte, 0, len(prefix)+10) + buf = append(buf, prefix...) + buf = strconv.AppendUint(buf, uint64(n), 10) + return string(buf) +} + +func synthPeerID(idx uint32) string { return synthID("p_", idx) } +func synthGroupID(seq uint32) string { return synthID("g_", seq) } +func synthPolicyID(seq uint32) string { return synthID("pol_", seq) } +func synthRouteID(seq uint32) string { return synthID("r_", seq) } +func synthNetworkResourceID(seq uint32) string { return synthID("nres_", seq) } +func synthPostureCheckID(seq uint32) string { return synthID("pc_", seq) } +func synthNetworkID(seq uint32) string { return synthID("net_", seq) } +func synthNameServerGroupID(seq uint32) string { return synthID("nsg_", seq) } + +func groupIDsFromSeqs(seqs []uint32) []string { + if len(seqs) == 0 { + return nil + } + out := make([]string, len(seqs)) + for i, s := range seqs { + out[i] = synthGroupID(s) + } + return out +} + +func uint32SliceToStrings(ports []uint32) []string { + if len(ports) == 0 { + return nil + } + out := make([]string, len(ports)) + for i, p := range ports { + out[i] = strconv.FormatUint(uint64(p), 10) + } + return out +} + +func portRangesFromProto(ranges []*proto.PortInfo_Range) []types.RulePortRange { + if len(ranges) == 0 { + return nil + } + out := make([]types.RulePortRange, 0, len(ranges)) + for _, r := range ranges { + out = append(out, types.RulePortRange{ + Start: uint16(r.Start), + End: uint16(r.End), + }) + } + return out +} + +func actionFromProto(a proto.RuleAction) types.PolicyTrafficActionType { + if a == proto.RuleAction_DROP { + return types.PolicyTrafficActionDrop + } + return types.PolicyTrafficActionAccept +} + +func protocolFromProto(p proto.RuleProtocol) types.PolicyRuleProtocolType { + switch p { + case proto.RuleProtocol_TCP: + return types.PolicyRuleProtocolTCP + case proto.RuleProtocol_UDP: + return types.PolicyRuleProtocolUDP + case proto.RuleProtocol_ICMP: + return types.PolicyRuleProtocolICMP + case proto.RuleProtocol_ALL: + return types.PolicyRuleProtocolALL + case proto.RuleProtocol_NETBIRD_SSH: + return types.PolicyRuleProtocolNetbirdSSH + default: + return types.PolicyRuleProtocolALL + } +} + +func encodeWgKeyBase64(raw []byte) string { + if len(raw) != 32 { + return "" + } + return base64.StdEncoding.EncodeToString(raw) +} + +func lookupAgentVersion(table []string, idx uint32) string { + if int(idx) < len(table) { + return table[idx] + } + return "" +} + +func stringSliceToSet(s []string) map[string]struct{} { + if len(s) == 0 { + return nil + } + out := make(map[string]struct{}, len(s)) + for _, v := range s { + out[v] = struct{}{} + } + return out +} + +// domainsFromPunycode is a thin wrapper that converts a punycode list back to +// the domain.List type the route.Route struct expects. It accepts the +// punycode strings as-is (no extra decoding) — symmetric with +// route.Domains.ToPunycodeList() used in the encoder. +func domainsFromPunycode(punycoded []string) domain.List { + if len(punycoded) == 0 { + return nil + } + out := make(domain.List, 0, len(punycoded)) + for _, d := range punycoded { + out = append(out, domain.Domain(d)) + } + return out +} diff --git a/shared/management/networkmap/encode.go b/shared/management/networkmap/encode.go new file mode 100644 index 000000000..ebaede64a --- /dev/null +++ b/shared/management/networkmap/encode.go @@ -0,0 +1,323 @@ +// Package networkmap contains the shared NetworkMap helpers that both the +// management server and the client agent need. +// +// The proto-conversion helpers (types.NetworkMap → proto.NetworkMap) live +// here so the client can run the same conversion locally after deriving its +// NetworkMap from a NetworkMapEnvelope, without taking a dependency on the +// server-side conversion package (which pulls in cloud integrations and is +// otherwise an unwanted internal import on the client). +// +// The helpers are pure functions over inputs — no caches, no IO, no logging +// beyond a context-aware error log when an individual user-id hash fails. +package networkmap + +import ( + "context" + + log "github.com/sirupsen/logrus" + goproto "google.golang.org/protobuf/proto" + + nbdns "github.com/netbirdio/netbird/dns" + "net/netip" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + nbroute "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" + "github.com/netbirdio/netbird/shared/sshauth" +) + +// ToProtocolRoutes converts a slice of typed routes to their proto form. +func ToProtocolRoutes(routes []*nbroute.Route) []*proto.Route { + protoRoutes := make([]*proto.Route, 0, len(routes)) + for _, r := range routes { + protoRoutes = append(protoRoutes, ToProtocolRoute(r)) + } + return protoRoutes +} + +// ToProtocolRoute converts one typed route to its proto form. +func ToProtocolRoute(route *nbroute.Route) *proto.Route { + return &proto.Route{ + ID: string(route.ID), + NetID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), + NetworkType: int64(route.NetworkType), + Peer: route.Peer, + Metric: int64(route.Metric), + Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, + SkipAutoApply: route.SkipAutoApply, + } +} + +// ToProtocolFirewallRules converts the firewall rules to the protocol form. +// When useSourcePrefixes is true, the compact SourcePrefixes field is +// populated alongside the deprecated PeerIP for forward compatibility. +// Wildcard rules ("0.0.0.0") are expanded into separate v4/v6 SourcePrefixes +// when includeIPv6 is true. +func ToProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, 0, len(rules)) + for i := range rules { + rule := rules[i] + + fwRule := &proto.FirewallRule{ + PolicyID: []byte(rule.PolicyID), + PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility + Direction: GetProtoDirection(rule.Direction), + Action: GetProtoAction(rule.Action), + Protocol: GetProtoProtocol(rule.Protocol), + Port: rule.Port, + } + + if useSourcePrefixes && rule.PeerIP != "" { + result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...) + } + + if ShouldUsePortRange(fwRule) { + fwRule.PortInfo = rule.PortRange.ToProto() + } + + result = append(result, fwRule) + } + return result +} + +// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any +// additional rules needed (e.g. a v6 wildcard clone when the peer IP is +// unspecified). +func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule { + addr, err := netip.ParseAddr(rule.PeerIP) + if err != nil { + return nil + } + + if !addr.IsUnspecified() { + fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())} + return nil + } + + v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0)) + fwRule.SourcePrefixes = [][]byte{v4Wildcard} + + if !includeIPv6 { + return nil + } + + v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule) + v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility + v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0)) + v6Rule.SourcePrefixes = [][]byte{v6Wildcard} + if ShouldUsePortRange(v6Rule) { + v6Rule.PortInfo = rule.PortRange.ToProto() + } + return []*proto.FirewallRule{v6Rule} +} + +// GetProtoDirection converts the direction to proto.RuleDirection. +func GetProtoDirection(direction int) proto.RuleDirection { + if direction == types.FirewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +// GetProtoAction converts the action to proto.RuleAction. +func GetProtoAction(action string) proto.RuleAction { + if action == string(types.PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// GetProtoProtocol converts the protocol to proto.RuleProtocol. +func GetProtoProtocol(protocol string) proto.RuleProtocol { + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case types.PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case types.PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case types.PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + case types.PolicyRuleProtocolNetbirdSSH: + return proto.RuleProtocol_NETBIRD_SSH + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// GetProtoPortInfo converts route-firewall-rule port info to proto.PortInfo. +func GetProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} + +// ShouldUsePortRange reports whether the firewall rule should use a port +// range rather than a single port (TCP/UDP without a single port). +func ShouldUsePortRange(rule *proto.FirewallRule) bool { + return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) +} + +// ToProtocolRoutesFirewallRules converts a slice of typed route-firewall +// rules to proto. +func ToProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: GetProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: GetProtoProtocol(rule.Protocol), + PortInfo: GetProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + Domains: rule.Domains.ToPunycodeList(), + PolicyID: []byte(rule.PolicyID), + RouteID: string(rule.RouteID), + } + } + return result +} + +// ConvertToProtoCustomZone converts an nbdns.CustomZone to its proto form. +func ConvertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { + protoZone := &proto.CustomZone{ + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + SearchDomainDisabled: zone.SearchDomainDisabled, + NonAuthoritative: zone.NonAuthoritative, + } + for _, record := range zone.Records { + protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ + Name: record.Name, + Type: int64(record.Type), + Class: record.Class, + TTL: int64(record.TTL), + RData: record.RData, + }) + } + return protoZone +} + +// ConvertToProtoNameServerGroup converts a NameServerGroup to its proto form. +func ConvertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { + protoGroup := &proto.NameServerGroup{ + Primary: nsGroup.Primary, + Domains: nsGroup.Domains, + SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, + NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), + } + for _, ns := range nsGroup.NameServers { + protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ + IP: ns.IP.String(), + Port: int64(ns.Port), + NSType: int64(ns.NSType), + }) + } + return protoGroup +} + +// DNSConfigCache is the cache contract for amortising NameServerGroup +// proto-conversion across peers in the same account. Server uses a concrete +// implementation; client passes nil (no cross-peer caching needed when +// rebuilding a single NetworkMap from an envelope). +type DNSConfigCache interface { + GetNameServerGroup(key string) (*proto.NameServerGroup, bool) + SetNameServerGroup(key string, value *proto.NameServerGroup) +} + +// ToProtocolDNSConfig converts nbdns.Config to proto.DNSConfig. If cache is +// non-nil, NameServerGroup proto values are cached by NSG.ID across calls — +// the server amortises this across peers, the client passes nil. +func ToProtocolDNSConfig(update nbdns.Config, cache DNSConfigCache, forwardPort int64) *proto.DNSConfig { + protoUpdate := &proto.DNSConfig{ + ServiceEnable: update.ServiceEnable, + CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), + NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, + } + + for _, zone := range update.CustomZones { + protoUpdate.CustomZones = append(protoUpdate.CustomZones, ConvertToProtoCustomZone(zone)) + } + + for _, nsGroup := range update.NameServerGroups { + if cache != nil { + if cachedGroup, exists := cache.GetNameServerGroup(nsGroup.ID); exists { + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) + continue + } + } + protoGroup := ConvertToProtoNameServerGroup(nsGroup) + if cache != nil { + cache.SetNameServerGroup(nsGroup.ID, protoGroup) + } + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) + } + + return protoUpdate +} + +// AppendRemotePeerConfig appends typed peers as proto.RemotePeerConfig +// entries to dst and returns the result. +func AppendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig { + for _, rPeer := range peers { + allowedIPs := []string{rPeer.IP.String() + "/32"} + if includeIPv6 && rPeer.IPv6.IsValid() { + allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128") + } + dst = append(dst, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: allowedIPs, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: rPeer.FQDN(dnsName), + AgentVersion: rPeer.Meta.WtVersion, + }) + } + return dst +} + +// BuildAuthorizedUsersProto deduplicates user-IDs into a hashed list and +// builds per-machine-user index maps. Returns (hashedUsers, machineUsers). +// Errors from individual hash failures are logged via the provided context; +// they leave the offending user out of the result but don't abort the build. +func BuildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) { + userIDToIndex := make(map[string]uint32) + var hashedUsers [][]byte + machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers)) + + for machineUser, users := range authorizedUsers { + indexes := make([]uint32, 0, len(users)) + for userID := range users { + idx, exists := userIDToIndex[userID] + if !exists { + hash, err := sshauth.HashUserID(userID) + if err != nil { + log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err) + continue + } + idx = uint32(len(hashedUsers)) + userIDToIndex[userID] = idx + hashedUsers = append(hashedUsers, hash[:]) + } + indexes = append(indexes, idx) + } + machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes} + } + + return hashedUsers, machineUsers +} diff --git a/shared/management/networkmap/envelope.go b/shared/management/networkmap/envelope.go new file mode 100644 index 000000000..295c26b7a --- /dev/null +++ b/shared/management/networkmap/envelope.go @@ -0,0 +1,204 @@ +package networkmap + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// EnvelopeResult is what the client engine consumes after receiving a +// component-format NetworkMap. Both fields are populated: +// +// - NetworkMap is the *proto.NetworkMap shape the engine reads today via +// update.GetNetworkMap() — built from the envelope's components by +// running Calculate() locally + converting back through the shared +// proto helpers + merging the optional ProxyPatch. +// - Components is the *types.NetworkMapComponents the engine retains so +// future incremental delta updates (Step 3) have a base to apply +// changes against. The client keeps it under its sync lock. +type EnvelopeResult struct { + NetworkMap *proto.NetworkMap + Components *types.NetworkMapComponents +} + +// EnvelopeToNetworkMap is the full client-side pipeline: decode the +// component envelope back to a typed NetworkMapComponents, run Calculate() +// locally to produce the typed NetworkMap, convert it to the wire form the +// engine consumes, and fold in any ProxyPatch the server attached. +// +// localPeerKey is the receiving peer's WG pub key (used to derive +// includeIPv6 / useSourcePrefixes from the receiving peer's own record in +// the components struct, mirroring legacy ToSyncResponse behaviour). +// +// dnsName is the account's DNS domain ("netbird.cloud" etc.); used when +// rebuilding the per-peer FQDNs that proto.RemotePeerConfig carries. +func EnvelopeToNetworkMap(ctx context.Context, env *proto.NetworkMapEnvelope, localPeerKey, dnsName string) (*EnvelopeResult, error) { + components, err := DecodeEnvelope(env) + if err != nil { + return nil, fmt.Errorf("decode envelope: %w", err) + } + + // Find the receiving peer in the decoded components by WG key so we can + // derive its capabilities and set components.PeerID for Calculate(). The + // envelope.peers list is index-addressed; we synthesized IDs as "p". + localPeerID, localPeer := findPeerByWgKey(components, localPeerKey) + if localPeer == nil { + return nil, fmt.Errorf("receiving peer (wg_key prefix %q) not found among %d decoded peers — components have no PeerID, Calculate would return empty", trimKey(localPeerKey), len(components.Peers)) + } + components.PeerID = localPeerID + + includeIPv6 := localPeer != nil && localPeer.SupportsIPv6() && localPeer.IPv6.IsValid() + useSourcePrefixes := localPeer != nil && localPeer.SupportsSourcePrefixes() + + typedNM := components.Calculate(ctx) + + full := env.GetFull() + dnsFwdPort := int64(0) + if full != nil { + dnsFwdPort = full.DnsForwarderPort + } + + protoNM := &proto.NetworkMap{ + Serial: typedNM.Network.CurrentSerial(), + } + if full != nil { + protoNM.PeerConfig = full.PeerConfig + } + protoNM.Routes = ToProtocolRoutes(typedNM.Routes) + protoNM.DNSConfig = ToProtocolDNSConfig(typedNM.DNSConfig, nil, dnsFwdPort) + + remotePeers := AppendRemotePeerConfig(nil, typedNM.Peers, dnsName, includeIPv6) + protoNM.RemotePeers = remotePeers + protoNM.RemotePeersIsEmpty = len(remotePeers) == 0 + + protoNM.OfflinePeers = AppendRemotePeerConfig(nil, typedNM.OfflinePeers, dnsName, includeIPv6) + + firewallRules := ToProtocolFirewallRules(typedNM.FirewallRules, includeIPv6, useSourcePrefixes) + protoNM.FirewallRules = firewallRules + protoNM.FirewallRulesIsEmpty = len(firewallRules) == 0 + + routesFirewallRules := ToProtocolRoutesFirewallRules(typedNM.RoutesFirewallRules) + protoNM.RoutesFirewallRules = routesFirewallRules + protoNM.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + + if typedNM.AuthorizedUsers != nil { + hashedUsers, machineUsers := BuildAuthorizedUsersProto(ctx, typedNM.AuthorizedUsers) + userIDClaim := "" + if full != nil { + userIDClaim = full.UserIdClaim + } + protoNM.SshAuth = &proto.SSHAuth{ + AuthorizedUsers: hashedUsers, + MachineUsers: machineUsers, + UserIDClaim: userIDClaim, + } + } + + if typedNM.ForwardingRules != nil { + forwardingRules := make([]*proto.ForwardingRule, 0, len(typedNM.ForwardingRules)) + for _, rule := range typedNM.ForwardingRules { + forwardingRules = append(forwardingRules, rule.ToProto()) + } + protoNM.ForwardingRules = forwardingRules + } + + // Merge the proxy patch the server attached. Mirrors the legacy + // NetworkMap.Merge step that the server runs after Calculate(). + if full != nil && full.ProxyPatch != nil { + mergeProxyPatch(protoNM, full.ProxyPatch) + } + + return &EnvelopeResult{ + NetworkMap: protoNM, + Components: components, + }, nil +} + +// mergeProxyPatch folds a ProxyPatch's pre-expanded fragments into the +// proto.NetworkMap that Calculate() produced. Mirrors types.NetworkMap.Merge +// — same six collections, deduplicated where the legacy merge dedupes. +func mergeProxyPatch(nm *proto.NetworkMap, patch *proto.ProxyPatch) { + nm.RemotePeers = appendUniquePeers(nm.RemotePeers, patch.Peers) + nm.OfflinePeers = appendUniquePeers(nm.OfflinePeers, patch.OfflinePeers) + nm.FirewallRules = append(nm.FirewallRules, patch.FirewallRules...) + nm.Routes = append(nm.Routes, patch.Routes...) + nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, patch.RouteFirewallRules...) + nm.ForwardingRules = append(nm.ForwardingRules, patch.ForwardingRules...) + if len(nm.RemotePeers) > 0 { + nm.RemotePeersIsEmpty = false + } + if len(nm.FirewallRules) > 0 { + nm.FirewallRulesIsEmpty = false + } + if len(nm.RoutesFirewallRules) > 0 { + nm.RoutesFirewallRulesIsEmpty = false + } +} + +// appendUniquePeers dedupes by WgPubKey — mirrors legacy +// mergeUniquePeersByID's intent (legacy keyed off Peer.ID; in proto form the +// closest stable identifier is WgPubKey). +func appendUniquePeers(dst, extra []*proto.RemotePeerConfig) []*proto.RemotePeerConfig { + if len(extra) == 0 { + return dst + } + seen := make(map[string]struct{}, len(dst)) + for _, p := range dst { + seen[p.WgPubKey] = struct{}{} + } + for _, p := range extra { + if _, ok := seen[p.WgPubKey]; ok { + continue + } + seen[p.WgPubKey] = struct{}{} + dst = append(dst, p) + } + return dst +} + +func trimKey(s string) string { + if len(s) > 12 { + return s[:12] + } + return s +} + +// findPeerByWgKey locates the receiving peer in the decoded components by +// matching its WireGuard public key. Compares raw 32-byte decode output — +// not the base64 string — because production data has occasional non-canonical +// padding bits that round-trip through the envelope's `bytes wg_pub_key` +// field, canonicalising the encoding (semantically equivalent key, different +// string). Decodes `wgKey` once up front and reuses a stack buffer in the +// loop so an N-peer search is ~zero-alloc. +func findPeerByWgKey(c *types.NetworkMapComponents, wgKey string) (string, *nbpeer.Peer) { + const wgKeyRawLen = 32 + var ( + targetRaw [wgKeyRawLen]byte + haveRaw bool + ) + if n, err := base64.StdEncoding.Decode(targetRaw[:], []byte(wgKey)); err == nil && n == wgKeyRawLen { + haveRaw = true + } + var peerRaw [wgKeyRawLen]byte + for id, p := range c.Peers { + if p == nil { + continue + } + if p.Key == wgKey { + return id, p + } + if !haveRaw { + continue + } + n, err := base64.StdEncoding.Decode(peerRaw[:], []byte(p.Key)) + if err == nil && n == wgKeyRawLen && bytes.Equal(peerRaw[:], targetRaw[:]) { + return id, p + } + } + return "", nil +} diff --git a/shared/management/networkmap/envelope_test.go b/shared/management/networkmap/envelope_test.go new file mode 100644 index 000000000..70e705737 --- /dev/null +++ b/shared/management/networkmap/envelope_test.go @@ -0,0 +1,173 @@ +package networkmap_test + +import ( + "context" + "crypto/rand" + "encoding/base64" + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/require" + goproto "google.golang.org/protobuf/proto" + + mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// TestEnvelopeToNetworkMap_RoundTrip exercises the full client-side pipeline: +// build a small components struct, encode an envelope, marshal/unmarshal the +// wire bytes, decode back via EnvelopeToNetworkMap, and verify the result is +// non-empty and consistent. Deeper per-field semantic equivalence with the +// legacy server path is covered by the prod-DB equivalence test in +// management/server/store/networkmap_envelope_equivalence_test.go. +func TestEnvelopeToNetworkMap_RoundTrip(t *testing.T) { + c, localPeerKey := buildSmokeComponents(t) + + envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{ + Components: c, + DNSDomain: "netbird.cloud", + }) + + wire, err := goproto.Marshal(envelope) + require.NoError(t, err, "marshal envelope") + + var decoded proto.NetworkMapEnvelope + require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope") + + result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud") + require.NoError(t, err, "EnvelopeToNetworkMap") + require.NotNil(t, result) + require.NotNil(t, result.NetworkMap, "decoded NetworkMap must be non-nil") + require.NotNil(t, result.Components, "Components must be retained for future delta updates") + require.NotNil(t, result.Components.AccountSettings) + require.NotEmpty(t, result.NetworkMap.RemotePeers, "two-peer allow policy should produce one remote peer") + require.NotEmpty(t, result.NetworkMap.FirewallRules, "two-peer allow policy should produce firewall rules") +} + +// TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH guards against the +// scenario where a rule with Protocol=NetbirdSSH leaks the enum value into +// proto.FirewallRule.Protocol. Calculate() must rewrite NetbirdSSH → TCP +// before forming firewall rules (see networkmap_components.go:282 and +// account.go:868). Without that rewrite, agents fall into UNKNOWN-protocol +// handling, which on some platforms downgrades to allow-all — a real +// security regression. +func TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH(t *testing.T) { + c, localPeerKey := buildSmokeComponents(t) + // Replace the smoke policy with a NetbirdSSH-protocol allow. + c.Policies = []*types.Policy{{ + ID: "pol-ssh", AccountSeqID: 2, Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolNetbirdSSH, + Bidirectional: true, + Sources: []string{"group-all"}, + Destinations: []string{"group-all"}, + }}, + }} + + envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{ + Components: c, + DNSDomain: "netbird.cloud", + }) + wire, err := goproto.Marshal(envelope) + require.NoError(t, err) + var decoded proto.NetworkMapEnvelope + require.NoError(t, goproto.Unmarshal(wire, &decoded)) + + result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud") + require.NoError(t, err) + require.NotEmpty(t, result.NetworkMap.FirewallRules, "ssh policy should produce firewall rules") + for i, fr := range result.NetworkMap.FirewallRules { + require.NotEqualf(t, proto.RuleProtocol_NETBIRD_SSH, fr.Protocol, + "FirewallRules[%d].Protocol must be the rewritten TCP, not NETBIRD_SSH", i) + } +} + +func TestEnvelopeToNetworkMap_NilEnvelope(t *testing.T) { + _, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), nil, "key", "netbird.cloud") + require.Error(t, err, "nil envelope must produce an error rather than panic") +} + +func TestEnvelopeToNetworkMap_FullPayloadMissing(t *testing.T) { + env := &proto.NetworkMapEnvelope{} + _, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), env, "key", "netbird.cloud") + require.Error(t, err, "envelope with no Full payload must produce an error") +} + +// buildSmokeComponents returns a minimal NetworkMapComponents (2 peers, 1 +// group, 1 allow policy) plus the receiving peer's WG public key. Sufficient +// to validate the encode → marshal → decode → Calculate pipeline produces +// non-empty output. +func buildSmokeComponents(t *testing.T) (*types.NetworkMapComponents, string) { + t.Helper() + + peerAKey := randomWgKey(t) + peerBKey := randomWgKey(t) + + peerA := &nbpeer.Peer{ + ID: "peer-A", + Key: peerAKey, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + DNSLabel: "peerA", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"}, + } + peerB := &nbpeer.Peer{ + ID: "peer-B", + Key: peerBKey, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}), + DNSLabel: "peerB", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"}, + } + + group := &types.Group{ + ID: "group-all", AccountSeqID: 1, Name: "All", + Peers: []string{"peer-A", "peer-B"}, + } + + policy := &types.Policy{ + ID: "pol-allow", AccountSeqID: 1, Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-allow", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-all"}, + Destinations: []string{"group-all"}, + }}, + } + + c := &types.NetworkMapComponents{ + PeerID: "peer-A", + Network: &types.Network{ + Identifier: "net-smoke", + Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}, + Serial: 1, + }, + AccountSettings: &types.AccountSettingsInfo{}, + DNSSettings: &types.DNSSettings{}, + Peers: map[string]*nbpeer.Peer{ + "peer-A": peerA, + "peer-B": peerB, + }, + Groups: map[string]*types.Group{ + "group-all": group, + }, + Policies: []*types.Policy{policy}, + } + return c, peerAKey +} + +func randomWgKey(t *testing.T) string { + t.Helper() + var raw [32]byte + _, err := rand.Read(raw[:]) + require.NoError(t, err) + return base64.StdEncoding.EncodeToString(raw[:]) +}