mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-08 09:49:54 +00:00
Compare commits
3 Commits
dns-skip-f
...
ssh-config
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad8765568b | ||
|
|
6e22e8a6fb | ||
|
|
9db7bec233 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -33,4 +33,3 @@ infrastructure_files/setup-*.env
|
|||||||
vendor/
|
vendor/
|
||||||
/netbird
|
/netbird
|
||||||
client/netbird-electron/
|
client/netbird-electron/
|
||||||
management/server/types/testdata/
|
|
||||||
|
|||||||
@@ -301,11 +301,10 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
||||||
for n, p := range fullStatus.Peers {
|
for n, p := range fullStatus.Peers {
|
||||||
pi := PeerInfo{
|
pi := PeerInfo{
|
||||||
IP: p.IP,
|
p.IP,
|
||||||
IPv6: p.IPv6,
|
p.FQDN,
|
||||||
FQDN: p.FQDN,
|
int(p.ConnStatus),
|
||||||
ConnStatus: int(p.ConnStatus),
|
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||||
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
}
|
}
|
||||||
@@ -337,84 +336,43 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
|
||||||
v6Merged := route.V6ExitMergeSet(routesMap)
|
|
||||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
|
||||||
|
|
||||||
networkArray := &NetworkArray{
|
networkArray := &NetworkArray{
|
||||||
items: make([]Network, 0),
|
items: make([]Network, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, routes := range routesMap {
|
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||||
|
|
||||||
|
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, skip := v6Merged[id]; skip {
|
|
||||||
continue
|
r := routes[0]
|
||||||
|
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||||
|
netStr := r.Network.String()
|
||||||
|
|
||||||
|
if r.IsDynamic() {
|
||||||
|
netStr = r.Domains.SafeString()
|
||||||
}
|
}
|
||||||
|
|
||||||
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
|
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
if network == nil {
|
if err != nil {
|
||||||
|
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
networkArray.Add(*network)
|
network := Network{
|
||||||
|
Name: string(id),
|
||||||
|
Network: netStr,
|
||||||
|
Peer: routePeer.FQDN,
|
||||||
|
Status: routePeer.ConnStatus.String(),
|
||||||
|
IsSelected: routeSelector.IsSelected(id),
|
||||||
|
Domains: domains,
|
||||||
|
}
|
||||||
|
networkArray.Add(network)
|
||||||
}
|
}
|
||||||
return networkArray
|
return networkArray
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
|
|
||||||
r := routes[0]
|
|
||||||
netStr := r.Network.String()
|
|
||||||
if r.IsDynamic() {
|
|
||||||
netStr = r.Domains.SafeString()
|
|
||||||
}
|
|
||||||
|
|
||||||
routePeer, err := c.findBestRoutePeer(routes)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("could not get peer info for route %s: %v", id, err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
network := &Network{
|
|
||||||
Name: string(id),
|
|
||||||
Network: netStr,
|
|
||||||
Peer: routePeer.FQDN,
|
|
||||||
Status: routePeer.ConnStatus.String(),
|
|
||||||
IsSelected: selected,
|
|
||||||
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
|
|
||||||
}
|
|
||||||
|
|
||||||
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
|
|
||||||
network.Network = "0.0.0.0/0, ::/0"
|
|
||||||
}
|
|
||||||
|
|
||||||
return network
|
|
||||||
}
|
|
||||||
|
|
||||||
// findBestRoutePeer returns the peer actively routing traffic for the given
|
|
||||||
// HA route group. Falls back to the first connected peer, then the first peer.
|
|
||||||
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
|
|
||||||
netStr := routes[0].Network.String()
|
|
||||||
|
|
||||||
fullStatus := c.recorder.GetFullStatus()
|
|
||||||
for _, p := range fullStatus.Peers {
|
|
||||||
if _, ok := p.GetRoutes()[netStr]; ok {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range routes {
|
|
||||||
p, err := c.recorder.GetPeer(r.Peer)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if p.ConnStatus == peer.StatusConnected {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.recorder.GetPeer(routes[0].Peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
dnsServer, err := dns.GetServerDns()
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ const (
|
|||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
type PeerInfo struct {
|
type PeerInfo struct {
|
||||||
IP string
|
IP string
|
||||||
IPv6 string
|
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus int
|
ConnStatus int
|
||||||
Routes PeerRoutes
|
Routes PeerRoutes
|
||||||
|
|||||||
@@ -307,24 +307,6 @@ func (p *Preferences) SetBlockInbound(block bool) {
|
|||||||
p.configInput.BlockInbound = &block
|
p.configInput.BlockInbound = &block
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDisableIPv6 reads disable IPv6 setting from config file
|
|
||||||
func (p *Preferences) GetDisableIPv6() (bool, error) {
|
|
||||||
if p.configInput.DisableIPv6 != nil {
|
|
||||||
return *p.configInput.DisableIPv6, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.DisableIPv6, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisableIPv6 stores the given value and waits for commit
|
|
||||||
func (p *Preferences) SetDisableIPv6(disable bool) {
|
|
||||||
p.configInput.DisableIPv6 = &disable
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit writes out the changes to the config file
|
// Commit writes out the changes to the config file
|
||||||
func (p *Preferences) Commit() error {
|
func (p *Preferences) Commit() error {
|
||||||
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
|
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
|
||||||
|
|||||||
@@ -18,12 +18,9 @@ func executeRouteToggle(id string, manager routemanager.Manager,
|
|||||||
netID := route.NetID(id)
|
netID := route.NetID(id)
|
||||||
routes := []route.NetID{netID}
|
routes := []route.NetID{netID}
|
||||||
|
|
||||||
routesMap := manager.GetClientRoutesWithNetID()
|
log.Debugf("%s with id: %s", operationName, id)
|
||||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
|
||||||
|
|
||||||
log.Debugf("%s with ids: %v", operationName, routes)
|
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||||
|
|
||||||
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
|
|
||||||
log.Debugf("error when %s: %s", operationName, err)
|
log.Debugf("error when %s: %s", operationName, err)
|
||||||
return fmt.Errorf("error %s: %w", operationName, err)
|
return fmt.Errorf("error %s: %w", operationName, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,9 +26,8 @@ type Anonymizer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
|
// 198.51.100.0, 100::
|
||||||
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
|
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||||
@@ -50,7 +48,7 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
|||||||
ip.IsLinkLocalUnicast() ||
|
ip.IsLinkLocalUnicast() ||
|
||||||
ip.IsLinkLocalMulticast() ||
|
ip.IsLinkLocalMulticast() ||
|
||||||
ip.IsInterfaceLocalMulticast() ||
|
ip.IsInterfaceLocalMulticast() ||
|
||||||
(ip.Is4() && ip.IsPrivate()) ||
|
ip.IsPrivate() ||
|
||||||
ip.IsUnspecified() ||
|
ip.IsUnspecified() ||
|
||||||
ip.IsMulticast() ||
|
ip.IsMulticast() ||
|
||||||
isWellKnown(ip) ||
|
isWellKnown(ip) ||
|
||||||
@@ -98,11 +96,6 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||||
// Handle CIDR notation (e.g. "2001:db8::/32")
|
|
||||||
if prefix, err := netip.ParsePrefix(ip); err == nil {
|
|
||||||
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := netip.ParseAddr(ip)
|
addr, err := netip.ParseAddr(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ip
|
return ip
|
||||||
@@ -157,7 +150,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
|
|||||||
if u.Opaque != "" {
|
if u.Opaque != "" {
|
||||||
host, port, err := net.SplitHostPort(u.Opaque)
|
host, port, err := net.SplitHostPort(u.Opaque)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
|
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||||
} else {
|
} else {
|
||||||
anonymizedHost = a.AnonymizeDomain(u.Opaque)
|
anonymizedHost = a.AnonymizeDomain(u.Opaque)
|
||||||
}
|
}
|
||||||
@@ -165,7 +158,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
|
|||||||
} else if u.Host != "" {
|
} else if u.Host != "" {
|
||||||
host, port, err := net.SplitHostPort(u.Host)
|
host, port, err := net.SplitHostPort(u.Host)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
|
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||||
} else {
|
} else {
|
||||||
anonymizedHost = a.AnonymizeDomain(u.Host)
|
anonymizedHost = a.AnonymizeDomain(u.Host)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
func TestAnonymizeIP(t *testing.T) {
|
func TestAnonymizeIP(t *testing.T) {
|
||||||
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
||||||
startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
|
startIPv6 := netip.MustParseAddr("100::")
|
||||||
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
|
|||||||
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
||||||
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
||||||
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
||||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||||
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
|
{"Second Public IPv6", "a::b", "100::1"},
|
||||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||||
{"Private IPv6", "fe80::1", "fe80::1"},
|
{"Private IPv6", "fe80::1", "fe80::1"},
|
||||||
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
||||||
}
|
}
|
||||||
@@ -274,27 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "IPv6 Address",
|
name: "IPv6 Address",
|
||||||
input: "Access attempted from 2001:db8::ff00:42",
|
input: "Access attempted from 2001:db8::ff00:42",
|
||||||
expect: "Access attempted from 2001:db8:ffff::",
|
expect: "Access attempted from 100::",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 Address with Port",
|
name: "IPv6 Address with Port",
|
||||||
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
||||||
expect: "Access attempted from [2001:db8:ffff::]:8080",
|
expect: "Access attempted from [100::]:8080",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Both IPv4 and IPv6",
|
name: "Both IPv4 and IPv6",
|
||||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||||
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
|
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "STUN URI with IPv6",
|
|
||||||
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
|
|
||||||
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "HTTPS URI with IPv6",
|
|
||||||
input: "Visit https://[2001:db8::ff00:42]:443/path",
|
|
||||||
expect: "Visit https://[2001:db8:ffff::]:443/path",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||||
target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port))
|
target := fmt.Sprintf("%s:%d", addr, port)
|
||||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||||
KnownHostsFile: knownHostsFile,
|
KnownHostsFile: knownHostsFile,
|
||||||
IdentityFile: identityFile,
|
IdentityFile: identityFile,
|
||||||
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
|
|||||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
|
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||||
func normalizeLocalHost(host string) string {
|
func normalizeLocalHost(host string) string {
|
||||||
if host == "*" {
|
if host == "*" {
|
||||||
return ""
|
return "0.0.0.0"
|
||||||
}
|
}
|
||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "wildcard bind all interfaces",
|
name: "wildcard bind all interfaces",
|
||||||
spec: "*:8080:localhost:80",
|
spec: "*:8080:localhost:80",
|
||||||
expectedLocal: ":8080",
|
expectedLocal: "0.0.0.0:8080",
|
||||||
expectedRemote: "localhost:80",
|
expectedRemote: "localhost:80",
|
||||||
expectError: false,
|
expectError: false,
|
||||||
description: "Wildcard * should bind to all interfaces (dual-stack)",
|
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wildcard for port only",
|
name: "wildcard for port only",
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
var (
|
var (
|
||||||
detailFlag bool
|
detailFlag bool
|
||||||
ipv4Flag bool
|
ipv4Flag bool
|
||||||
ipv6Flag bool
|
|
||||||
jsonFlag bool
|
jsonFlag bool
|
||||||
yamlFlag bool
|
yamlFlag bool
|
||||||
ipsFilter []string
|
ipsFilter []string
|
||||||
@@ -46,9 +45,8 @@ func init() {
|
|||||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||||
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||||
@@ -103,14 +101,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if ipv6Flag {
|
|
||||||
ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6()
|
|
||||||
if ipv6 != "" {
|
|
||||||
cmd.Print(parseInterfaceIP(ipv6))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
pm := profilemanager.NewProfileManager()
|
pm := profilemanager.NewProfileManager()
|
||||||
var profName string
|
var profName string
|
||||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ const (
|
|||||||
disableFirewallFlag = "disable-firewall"
|
disableFirewallFlag = "disable-firewall"
|
||||||
blockLANAccessFlag = "block-lan-access"
|
blockLANAccessFlag = "block-lan-access"
|
||||||
blockInboundFlag = "block-inbound"
|
blockInboundFlag = "block-inbound"
|
||||||
disableIPv6Flag = "disable-ipv6"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -18,7 +17,6 @@ var (
|
|||||||
disableFirewall bool
|
disableFirewall bool
|
||||||
blockLANAccess bool
|
blockLANAccess bool
|
||||||
blockInbound bool
|
blockInbound bool
|
||||||
disableIPv6 bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -41,7 +39,4 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||||
"This overrides any policies received from the management service.")
|
"This overrides any policies received from the management service.")
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false,
|
|
||||||
"Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -435,10 +435,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
req.BlockInbound = &blockInbound
|
req.BlockInbound = &blockInbound
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableIPv6Flag).Changed {
|
|
||||||
req.DisableIpv6 = &disableIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
req.LazyConnectionEnabled = &lazyConnEnabled
|
req.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
}
|
}
|
||||||
@@ -556,10 +552,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.BlockInbound = &blockInbound
|
ic.BlockInbound = &blockInbound
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableIPv6Flag).Changed {
|
|
||||||
ic.DisableIPv6 = &disableIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
}
|
}
|
||||||
@@ -674,10 +666,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.BlockInbound = &blockInbound
|
loginRequest.BlockInbound = &blockInbound
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableIPv6Flag).Changed {
|
|
||||||
loginRequest.DisableIpv6 = &disableIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,8 +80,6 @@ type Options struct {
|
|||||||
StatePath string
|
StatePath string
|
||||||
// DisableClientRoutes disables the client routes
|
// DisableClientRoutes disables the client routes
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
// DisableIPv6 disables IPv6 overlay addressing
|
|
||||||
DisableIPv6 bool
|
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||||
@@ -173,7 +171,6 @@ func New(opts Options) (*Client, error) {
|
|||||||
PreSharedKey: &opts.PreSharedKey,
|
PreSharedKey: &opts.PreSharedKey,
|
||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
DisableIPv6: &opts.DisableIPv6,
|
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
WireguardPort: opts.WireguardPort,
|
WireguardPort: opts.WireguardPort,
|
||||||
MTU: opts.MTU,
|
MTU: opts.MTU,
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ type aclManager struct {
|
|||||||
entries aclEntries
|
entries aclEntries
|
||||||
optionalEntries map[string][]entry
|
optionalEntries map[string][]entry
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
v6 bool
|
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
@@ -52,7 +51,6 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
|
|||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,11 +85,7 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
chain := chainNameInputRules
|
chain := chainNameInputRules
|
||||||
|
|
||||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||||
if m.v6 && ipsetName != "" {
|
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||||
ipsetName += "-v6"
|
|
||||||
}
|
|
||||||
proto := protoForFamily(protocol, m.v6)
|
|
||||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
|
|
||||||
mangleSpecs := slices.Clone(specs)
|
mangleSpecs := slices.Clone(specs)
|
||||||
mangleSpecs = append(mangleSpecs,
|
mangleSpecs = append(mangleSpecs,
|
||||||
@@ -115,7 +109,6 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
ip: ip.String(),
|
ip: ip.String(),
|
||||||
chain: chain,
|
chain: chain,
|
||||||
specs: specs,
|
specs: specs,
|
||||||
v6: m.v6,
|
|
||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,7 +161,6 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
ipsetName: ipsetName,
|
ipsetName: ipsetName,
|
||||||
ip: ip.String(),
|
ip: ip.String(),
|
||||||
chain: chain,
|
chain: chain,
|
||||||
v6: m.v6,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
@@ -421,13 +413,8 @@ func (m *aclManager) updateState() {
|
|||||||
currentState.Lock()
|
currentState.Lock()
|
||||||
defer currentState.Unlock()
|
defer currentState.Unlock()
|
||||||
|
|
||||||
if m.v6 {
|
currentState.ACLEntries = m.entries
|
||||||
currentState.ACLEntries6 = m.entries
|
currentState.ACLIPsetStore = m.ipsetStore
|
||||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
|
||||||
} else {
|
|
||||||
currentState.ACLEntries = m.entries
|
|
||||||
currentState.ACLIPsetStore = m.ipsetStore
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
@@ -435,22 +422,13 @@ func (m *aclManager) updateState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
|
||||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
|
||||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
|
||||||
if v6 && protocol == firewall.ProtocolICMP {
|
|
||||||
return "ipv6-icmp"
|
|
||||||
}
|
|
||||||
return string(protocol)
|
|
||||||
}
|
|
||||||
|
|
||||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
// don't use IP matching if IP is 0.0.0.0
|
// don't use IP matching if IP is 0.0.0.0
|
||||||
matchByIP := !ip.IsUnspecified()
|
matchByIP := !ip.IsUnspecified()
|
||||||
|
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
} else {
|
} else {
|
||||||
specs = append(specs, "-s", ip.String())
|
specs = append(specs, "-s", ip.String())
|
||||||
}
|
}
|
||||||
@@ -496,9 +474,6 @@ func (m *aclManager) createIPSet(name string) error {
|
|||||||
opts := ipset.CreateOptions{
|
opts := ipset.CreateOptions{
|
||||||
Replace: true,
|
Replace: true,
|
||||||
}
|
}
|
||||||
if m.v6 {
|
|
||||||
opts.Family = ipset.FamilyIPV6
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
|||||||
@@ -18,10 +18,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type resetter interface {
|
|
||||||
Reset() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
@@ -32,11 +28,6 @@ type Manager struct {
|
|||||||
aclMgr *aclManager
|
aclMgr *aclManager
|
||||||
router *router
|
router *router
|
||||||
rawSupported bool
|
rawSupported bool
|
||||||
|
|
||||||
// IPv6 counterparts, nil when no v6 overlay
|
|
||||||
ipv6Client *iptables.IPTables
|
|
||||||
aclMgr6 *aclManager
|
|
||||||
router6 *router
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@@ -67,43 +58,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if wgIface.Address().HasIPv6() {
|
|
||||||
if err := m.createIPv6Components(wgIface, mtu); err != nil {
|
|
||||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
|
||||||
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("init ip6tables: %w", err)
|
|
||||||
}
|
|
||||||
m.ipv6Client = ip6Client
|
|
||||||
|
|
||||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create v6 router: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Share the same IP forwarding state with the v4 router, since
|
|
||||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
|
||||||
m.router6.ipFwdState = m.router.ipFwdState
|
|
||||||
|
|
||||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) hasIPv6() bool {
|
|
||||||
return m.ipv6Client != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
state := &ShutdownState{
|
state := &ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
@@ -117,8 +74,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initChains(stateManager); err != nil {
|
if err := m.router.init(stateManager); err != nil {
|
||||||
return err
|
return fmt.Errorf("router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclMgr.init(stateManager); err != nil {
|
||||||
|
// TODO: cleanup router
|
||||||
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initNoTrackChain(); err != nil {
|
if err := m.initNoTrackChain(); err != nil {
|
||||||
@@ -141,41 +103,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initChains initializes router and ACL chains for both address families,
|
|
||||||
// rolling back on failure.
|
|
||||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
|
||||||
type initStep struct {
|
|
||||||
name string
|
|
||||||
init func(*statemanager.Manager) error
|
|
||||||
mgr resetter
|
|
||||||
}
|
|
||||||
|
|
||||||
steps := []initStep{
|
|
||||||
{"router", m.router.init, m.router},
|
|
||||||
{"acl manager", m.aclMgr.init, m.aclMgr},
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
|
||||||
steps = append(steps,
|
|
||||||
initStep{"v6 router", m.router6.init, m.router6},
|
|
||||||
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var initialized []initStep
|
|
||||||
for _, s := range steps {
|
|
||||||
if err := s.init(stateManager); err != nil {
|
|
||||||
for i := len(initialized) - 1; i >= 0; i-- {
|
|
||||||
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
|
||||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("%s init: %w", s.name, err)
|
|
||||||
}
|
|
||||||
initialized = append(initialized, s)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPeerFiltering adds a rule to the firewall
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
@@ -191,13 +118,7 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if ip.To4() != nil {
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
}
|
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
@@ -211,48 +132,25 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if isIPv6RouteRule(sources, destination) {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
if !m.hasIPv6() {
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
|
||||||
if destination.IsPrefix() {
|
|
||||||
return destination.Prefix.Addr().Is6()
|
|
||||||
}
|
|
||||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
|
||||||
return m.aclMgr6.DeletePeerRule(rule)
|
|
||||||
}
|
|
||||||
return m.aclMgr.DeletePeerRule(rule)
|
return m.aclMgr.DeletePeerRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
return ok && r.v6
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule.
|
|
||||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
|
||||||
return m.router6.DeleteRouteRule(rule)
|
|
||||||
}
|
|
||||||
return m.router.DeleteRouteRule(rule)
|
return m.router.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,65 +166,18 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
return m.router.AddNatRule(pair)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddNatRule(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.router.AddNatRule(pair); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dynamic routes need NAT in both tables since resolved IPs can be
|
|
||||||
// either v4 or v6. This covers both DomainSet (modern) and the legacy
|
|
||||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
|
||||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
|
||||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
return m.router.RemoveNatRule(pair)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.router6.RemoveNatRule(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
|
||||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
return err
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
|
||||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@@ -340,15 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.aclMgr6.Reset(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
|
||||||
}
|
|
||||||
if err := m.router6.Reset(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.aclMgr.Reset(); err != nil {
|
if err := m.aclMgr.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||||
}
|
}
|
||||||
@@ -376,21 +218,24 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
var merr *multierror.Error
|
_, err := m.AddPeerFiltering(
|
||||||
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
nil,
|
||||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
|
net.IP{0, 0, 0, 0},
|
||||||
}
|
firewall.ProtocolALL,
|
||||||
if m.hasIPv6() {
|
nil,
|
||||||
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
nil,
|
||||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
|
firewall.ActionAccept,
|
||||||
}
|
"",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
@@ -420,12 +265,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if rule.TranslatedAddress.Is6() {
|
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddDNATRule(rule)
|
|
||||||
}
|
|
||||||
return m.router.AddDNATRule(rule)
|
return m.router.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -434,9 +273,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
|
||||||
return m.router6.DeleteDNATRule(rule)
|
|
||||||
}
|
|
||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -445,82 +281,39 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
return m.router.UpdateSet(set, prefixes)
|
||||||
for _, p := range prefixes {
|
|
||||||
if p.Addr().Is6() {
|
|
||||||
v6Prefixes = append(v6Prefixes, p)
|
|
||||||
} else {
|
|
||||||
v4Prefixes = append(v4Prefixes, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
|
||||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
|
||||||
return fmt.Errorf("update v6 set: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -54,10 +54,8 @@ const (
|
|||||||
snatSuffix = "_snat"
|
snatSuffix = "_snat"
|
||||||
fwdSuffix = "_fwd"
|
fwdSuffix = "_fwd"
|
||||||
|
|
||||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||||
ipv4TCPHeaderSize = 40
|
ipTCPHeaderMinSize = 40
|
||||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
|
||||||
ipv6TCPHeaderSize = 60
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ruleInfo struct {
|
type ruleInfo struct {
|
||||||
@@ -88,7 +86,6 @@ type router struct {
|
|||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
mtu uint16
|
mtu uint16
|
||||||
v6 bool
|
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
@@ -100,7 +97,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
|||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,11 +186,6 @@ func (r *router) AddRouteFiltering(
|
|||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) hasRule(id string) bool {
|
|
||||||
_, ok := r.rules[id]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
@@ -401,13 +392,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
|
|
||||||
// Remove jump rules from built-in chains before deleting custom chains,
|
// Remove jump rules from built-in chains before deleting custom chains,
|
||||||
// otherwise the chain deletion fails with "device or resource busy".
|
// otherwise the chain deletion fails with "device or resource busy".
|
||||||
if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil {
|
jumpRule := []string{"-j", chainNATOutput}
|
||||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||||
} else if ok {
|
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||||
jumpRule := []string{"-j", chainNATOutput}
|
|
||||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
|
||||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, chainInfo := range []struct {
|
for _, chainInfo := range []struct {
|
||||||
@@ -447,12 +434,6 @@ func (r *router) createContainers() error {
|
|||||||
{chainRTRDR, tableNat},
|
{chainRTRDR, tableNat},
|
||||||
{chainRTMSSCLAMP, tableMangle},
|
{chainRTMSSCLAMP, tableMangle},
|
||||||
} {
|
} {
|
||||||
// Fallback: clear chains that survived an unclean shutdown.
|
|
||||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
|
||||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
|
||||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
}
|
}
|
||||||
@@ -559,12 +540,9 @@ func (r *router) addPostroutingRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
|
// TODO: Add IPv6 support
|
||||||
func (r *router) addMSSClampingRules() error {
|
func (r *router) addMSSClampingRules() error {
|
||||||
overhead := uint16(ipv4TCPHeaderSize)
|
mss := r.mtu - ipTCPHeaderMinSize
|
||||||
if r.v6 {
|
|
||||||
overhead = ipv6TCPHeaderSize
|
|
||||||
}
|
|
||||||
mss := r.mtu - overhead
|
|
||||||
|
|
||||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||||
jumpRule := []string{
|
jumpRule := []string{
|
||||||
@@ -749,13 +727,8 @@ func (r *router) updateState() {
|
|||||||
currentState.Lock()
|
currentState.Lock()
|
||||||
defer currentState.Unlock()
|
defer currentState.Unlock()
|
||||||
|
|
||||||
if r.v6 {
|
currentState.RouteRules = r.rules
|
||||||
currentState.RouteRules6 = r.rules
|
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
|
||||||
} else {
|
|
||||||
currentState.RouteRules = r.rules
|
|
||||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
@@ -883,7 +856,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey+fwdSuffix)
|
delete(r.rules, ruleKey+fwdSuffix)
|
||||||
@@ -910,7 +883,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
|
|||||||
rule = append(rule, destExp...)
|
rule = append(rule, destExp...)
|
||||||
|
|
||||||
if params.Proto != firewall.ProtocolALL {
|
if params.Proto != firewall.ProtocolALL {
|
||||||
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
|
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||||
}
|
}
|
||||||
@@ -927,12 +900,11 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if network.IsSet() {
|
if network.IsSet() {
|
||||||
name := r.ipsetName(network.Set.HashedName())
|
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
|
||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||||
}
|
}
|
||||||
if network.IsPrefix() {
|
if network.IsPrefix() {
|
||||||
return []string{flag, network.Prefix.String()}, nil
|
return []string{flag, network.Prefix.String()}, nil
|
||||||
@@ -943,23 +915,27 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
name := r.ipsetName(set.HashedName())
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
// TODO: Implement IPv6 support
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
return nil
|
return nil
|
||||||
@@ -967,12 +943,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
|
|
||||||
dnatRule := []string{
|
dnatRule := []string{
|
||||||
"-i", r.wgIface.Name(),
|
"-i", r.wgIface.Name(),
|
||||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
"-p", strings.ToLower(string(protocol)),
|
||||||
"--dport", strconv.Itoa(int(originalPort)),
|
"--dport", strconv.Itoa(int(sourcePort)),
|
||||||
"-d", localAddr.String(),
|
"-d", localAddr.String(),
|
||||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
"-j", "DNAT",
|
"-j", "DNAT",
|
||||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleInfo := ruleInfo{
|
ruleInfo := ruleInfo{
|
||||||
@@ -991,8 +967,8 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||||
@@ -1037,8 +1013,8 @@ func (r *router) ensureNATOutputChain() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
return nil
|
return nil
|
||||||
@@ -1049,11 +1025,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
}
|
}
|
||||||
|
|
||||||
dnatRule := []string{
|
dnatRule := []string{
|
||||||
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
|
"-p", strings.ToLower(string(protocol)),
|
||||||
"--dport", strconv.Itoa(int(originalPort)),
|
"--dport", strconv.Itoa(int(sourcePort)),
|
||||||
"-d", localAddr.String(),
|
"-d", localAddr.String(),
|
||||||
"-j", "DNAT",
|
"-j", "DNAT",
|
||||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||||
@@ -1066,8 +1042,8 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||||
@@ -1100,22 +1076,10 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
|
|
||||||
// to avoid collisions since ipsets are global in the kernel.
|
|
||||||
func (r *router) ipsetName(name string) string {
|
|
||||||
if r.v6 {
|
|
||||||
return name + "-v6"
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) createIPSet(name string) error {
|
func (r *router) createIPSet(name string) error {
|
||||||
opts := ipset.CreateOptions{
|
opts := ipset.CreateOptions{
|
||||||
Replace: true,
|
Replace: true,
|
||||||
}
|
}
|
||||||
if r.v6 {
|
|
||||||
opts.Family = ipset.FamilyIPV6
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ type Rule struct {
|
|||||||
mangleSpecs []string
|
mangleSpecs []string
|
||||||
ip string
|
ip string
|
||||||
chain string
|
chain string
|
||||||
v6 bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -34,12 +32,6 @@ type ShutdownState struct {
|
|||||||
|
|
||||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||||
|
|
||||||
// IPv6 counterparts
|
|
||||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
|
||||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
|
||||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
|
||||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -70,28 +62,6 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up v6 state even if the current run has no IPv6.
|
|
||||||
// The previous run may have left ip6tables rules behind.
|
|
||||||
if !ipt.hasIPv6() {
|
|
||||||
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
|
|
||||||
log.Warnf("failed to create v6 components for cleanup: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ipt.hasIPv6() {
|
|
||||||
if s.RouteRules6 != nil {
|
|
||||||
ipt.router6.rules = s.RouteRules6
|
|
||||||
}
|
|
||||||
if s.RouteIPsetCounter6 != nil {
|
|
||||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
|
||||||
}
|
|
||||||
if s.ACLEntries6 != nil {
|
|
||||||
ipt.aclMgr6.entries = s.ACLEntries6
|
|
||||||
}
|
|
||||||
if s.ACLIPsetStore6 != nil {
|
|
||||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipt.Close(nil); err != nil {
|
if err := ipt.Close(nil); err != nil {
|
||||||
return fmt.Errorf("reset iptables manager: %w", err)
|
return fmt.Errorf("reset iptables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -12,10 +11,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrIPv6NotInitialized is returned when an IPv6 address is passed to a firewall
|
|
||||||
// method but the IPv6 firewall components were not initialized.
|
|
||||||
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ForwardingFormatPrefix = "netbird-fwd-"
|
ForwardingFormatPrefix = "netbird-fwd-"
|
||||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||||
@@ -169,16 +164,18 @@ type Manager interface {
|
|||||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||||
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
// RemoveInboundDNAT removes inbound DNAT rule
|
// RemoveInboundDNAT removes inbound DNAT rule
|
||||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||||
|
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||||
|
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -12,10 +10,6 @@ type RouterPair struct {
|
|||||||
Destination Network
|
Destination Network
|
||||||
Masquerade bool
|
Masquerade bool
|
||||||
Inverse bool
|
Inverse bool
|
||||||
// Dynamic indicates the route is domain-based. NAT rules for dynamic
|
|
||||||
// routes are duplicated to the v6 table so that resolved AAAA records
|
|
||||||
// are masqueraded correctly.
|
|
||||||
Dynamic bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetInversePair(pair RouterPair) RouterPair {
|
func GetInversePair(pair RouterPair) RouterPair {
|
||||||
@@ -26,17 +20,5 @@ func GetInversePair(pair RouterPair) RouterPair {
|
|||||||
Destination: pair.Source,
|
Destination: pair.Source,
|
||||||
Masquerade: pair.Masquerade,
|
Masquerade: pair.Masquerade,
|
||||||
Inverse: true,
|
Inverse: true,
|
||||||
Dynamic: pair.Dynamic,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToV6NatPair creates a v6 counterpart of a v4 NAT pair with `::/0` source
|
|
||||||
// and, for prefix destinations, `::/0` destination.
|
|
||||||
func ToV6NatPair(pair RouterPair) RouterPair {
|
|
||||||
v6 := pair
|
|
||||||
v6.Source = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
|
||||||
if v6.Destination.IsPrefix() {
|
|
||||||
v6.Destination = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
return v6
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -33,12 +33,15 @@ const (
|
|||||||
|
|
||||||
const flushError = "flush: %w"
|
const flushError = "flush: %w"
|
||||||
|
|
||||||
|
var (
|
||||||
|
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
|
)
|
||||||
|
|
||||||
type AclManager struct {
|
type AclManager struct {
|
||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
sConn *nftables.Conn
|
sConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
routingFwChainName string
|
||||||
af addrFamily
|
|
||||||
|
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
chainInputRules *nftables.Chain
|
chainInputRules *nftables.Chain
|
||||||
@@ -64,7 +67,6 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
|||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
workTable: table,
|
workTable: table,
|
||||||
routingFwChainName: routingFwChainName,
|
routingFwChainName: routingFwChainName,
|
||||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
|
||||||
|
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
rules: make(map[string]*Rule),
|
rules: make(map[string]*Rule),
|
||||||
@@ -143,7 +145,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := ips[r.ip.String()]; ok {
|
if _, ok := ips[r.ip.String()]; ok {
|
||||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||||
}
|
}
|
||||||
@@ -252,11 +254,11 @@ func (m *AclManager) addIOFiltering(
|
|||||||
expressions = append(expressions, &expr.Payload{
|
expressions = append(expressions, &expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: m.af.protoOffset,
|
Offset: uint32(9),
|
||||||
Len: uint32(1),
|
Len: uint32(1),
|
||||||
})
|
})
|
||||||
|
|
||||||
protoData, err := m.af.protoNum(proto)
|
protoData, err := protoToInt(proto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||||
}
|
}
|
||||||
@@ -268,16 +270,19 @@ func (m *AclManager) addIOFiltering(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rawIP := ipToBytes(ip, m.af)
|
rawIP := ip.To4()
|
||||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||||
// in that case not add IP match expression into the rule definition
|
// in that case not add IP match expression into the rule definition
|
||||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
|
// source address position
|
||||||
|
addrOffset := uint32(12)
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: m.af.srcAddrOffset,
|
Offset: addrOffset,
|
||||||
Len: m.af.addrLen,
|
Len: 4,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// add individual IP for match if no ipset defined
|
// add individual IP for match if no ipset defined
|
||||||
@@ -582,7 +587,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
|||||||
|
|
||||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||||
rawIP := ipToBytes(ip, m.af)
|
rawIP := ip.To4()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||||
return nil, fmt.Errorf("get set name: %v", err)
|
return nil, fmt.Errorf("get set name: %v", err)
|
||||||
@@ -614,7 +619,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
|
|||||||
Name: name,
|
Name: name,
|
||||||
Table: table,
|
Table: table,
|
||||||
Dynamic: true,
|
Dynamic: true,
|
||||||
KeyType: m.af.setKeyType,
|
KeyType: nftables.TypeIPAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||||
@@ -702,12 +707,15 @@ func ifname(n string) []byte {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
switch protocol {
|
||||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
case firewall.ProtocolTCP:
|
||||||
if af.addrLen == 4 {
|
return unix.IPPROTO_TCP, nil
|
||||||
return ip.To4()
|
case firewall.ProtocolUDP:
|
||||||
|
return unix.IPPROTO_UDP, nil
|
||||||
|
case firewall.ProtocolICMP:
|
||||||
|
return unix.IPPROTO_ICMP, nil
|
||||||
}
|
}
|
||||||
return ip.To16()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// afIPv4 defines IPv4 header layout and nftables types.
|
|
||||||
afIPv4 = addrFamily{
|
|
||||||
protoOffset: 9,
|
|
||||||
srcAddrOffset: 12,
|
|
||||||
dstAddrOffset: 16,
|
|
||||||
addrLen: net.IPv4len,
|
|
||||||
totalBits: 8 * net.IPv4len,
|
|
||||||
setKeyType: nftables.TypeIPAddr,
|
|
||||||
tableFamily: nftables.TableFamilyIPv4,
|
|
||||||
icmpProto: unix.IPPROTO_ICMP,
|
|
||||||
}
|
|
||||||
// afIPv6 defines IPv6 header layout and nftables types.
|
|
||||||
afIPv6 = addrFamily{
|
|
||||||
protoOffset: 6,
|
|
||||||
srcAddrOffset: 8,
|
|
||||||
dstAddrOffset: 24,
|
|
||||||
addrLen: net.IPv6len,
|
|
||||||
totalBits: 8 * net.IPv6len,
|
|
||||||
setKeyType: nftables.TypeIP6Addr,
|
|
||||||
tableFamily: nftables.TableFamilyIPv6,
|
|
||||||
icmpProto: unix.IPPROTO_ICMPV6,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// addrFamily holds protocol-specific constants for nftables expression building.
|
|
||||||
type addrFamily struct {
|
|
||||||
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
|
|
||||||
protoOffset uint32
|
|
||||||
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
|
|
||||||
srcAddrOffset uint32
|
|
||||||
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
|
|
||||||
dstAddrOffset uint32
|
|
||||||
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
|
|
||||||
addrLen uint32
|
|
||||||
// totalBits is the address size in bits (32 for v4, 128 for v6)
|
|
||||||
totalBits int
|
|
||||||
// setKeyType is the nftables set data type for addresses
|
|
||||||
setKeyType nftables.SetDatatype
|
|
||||||
// tableFamily is the nftables table family
|
|
||||||
tableFamily nftables.TableFamily
|
|
||||||
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
|
|
||||||
icmpProto uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
// familyForAddr returns the address family for the given IP.
|
|
||||||
func familyForAddr(is4 bool) addrFamily {
|
|
||||||
if is4 {
|
|
||||||
return afIPv4
|
|
||||||
}
|
|
||||||
return afIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
// protoNum converts a firewall protocol to the IP protocol number,
|
|
||||||
// using the correct ICMP variant for the address family.
|
|
||||||
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
|
|
||||||
switch protocol {
|
|
||||||
case firewall.ProtocolTCP:
|
|
||||||
return unix.IPPROTO_TCP, nil
|
|
||||||
case firewall.ProtocolUDP:
|
|
||||||
return unix.IPPROTO_UDP, nil
|
|
||||||
case firewall.ProtocolICMP:
|
|
||||||
return af.icmpProto, nil
|
|
||||||
case firewall.ProtocolALL:
|
|
||||||
return 0, nil
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestExternalChainMonitorRootIntegration verifies that adding a new chain
|
|
||||||
// in an external (non-netbird) filter table triggers the reconciler.
|
|
||||||
// Requires CAP_NET_ADMIN; skip otherwise.
|
|
||||||
func TestExternalChainMonitorRootIntegration(t *testing.T) {
|
|
||||||
if os.Geteuid() != 0 {
|
|
||||||
t.Skip("root required")
|
|
||||||
}
|
|
||||||
|
|
||||||
calls := make(chan struct{}, 8)
|
|
||||||
var count atomic.Int32
|
|
||||||
rec := &countingReconciler{calls: calls, count: &count}
|
|
||||||
|
|
||||||
m := newExternalChainMonitor(rec)
|
|
||||||
m.start()
|
|
||||||
t.Cleanup(m.stop)
|
|
||||||
|
|
||||||
// Give the netlink subscription a moment to register.
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
conn := &nftables.Conn{}
|
|
||||||
table := conn.AddTable(&nftables.Table{
|
|
||||||
Name: "nbmon_integration_test",
|
|
||||||
Family: nftables.TableFamilyINet,
|
|
||||||
})
|
|
||||||
t.Cleanup(func() {
|
|
||||||
cleanup := &nftables.Conn{}
|
|
||||||
cleanup.DelTable(table)
|
|
||||||
_ = cleanup.Flush()
|
|
||||||
})
|
|
||||||
|
|
||||||
chain := conn.AddChain(&nftables.Chain{
|
|
||||||
Name: "filter_INPUT",
|
|
||||||
Table: table,
|
|
||||||
Hooknum: nftables.ChainHookInput,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
_ = chain
|
|
||||||
require.NoError(t, conn.Flush(), "create external test chain")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-calls:
|
|
||||||
// success
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatalf("reconcile was not invoked after creating an external chain")
|
|
||||||
}
|
|
||||||
require.GreaterOrEqual(t, count.Load(), int32(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
type countingReconciler struct {
|
|
||||||
calls chan struct{}
|
|
||||||
count *atomic.Int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *countingReconciler) reconcileExternalChains() error {
|
|
||||||
c.count.Add(1)
|
|
||||||
select {
|
|
||||||
case c.calls <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,199 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
externalMonitorReconcileDelay = 500 * time.Millisecond
|
|
||||||
externalMonitorInitInterval = 5 * time.Second
|
|
||||||
externalMonitorMaxInterval = 5 * time.Minute
|
|
||||||
externalMonitorRandomization = 0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
// externalChainReconciler re-applies passthrough accept rules to external
|
|
||||||
// nftables chains. Implementations must be safe to call from the monitor
|
|
||||||
// goroutine; the Manager locks its mutex internally.
|
|
||||||
type externalChainReconciler interface {
|
|
||||||
reconcileExternalChains() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// externalChainMonitor watches nftables netlink events and triggers a
|
|
||||||
// reconcile when a new table or chain appears (e.g. after
|
|
||||||
// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff
|
|
||||||
// reconnect.
|
|
||||||
type externalChainMonitor struct {
|
|
||||||
reconciler externalChainReconciler
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
cancel context.CancelFunc
|
|
||||||
done chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor {
|
|
||||||
return &externalChainMonitor{reconciler: r}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *externalChainMonitor) start() {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if m.cancel != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
m.cancel = cancel
|
|
||||||
m.done = make(chan struct{})
|
|
||||||
|
|
||||||
go m.run(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *externalChainMonitor) stop() {
|
|
||||||
m.mu.Lock()
|
|
||||||
cancel := m.cancel
|
|
||||||
done := m.done
|
|
||||||
m.cancel = nil
|
|
||||||
m.done = nil
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
if cancel == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *externalChainMonitor) run(ctx context.Context) {
|
|
||||||
defer close(m.done)
|
|
||||||
|
|
||||||
bo := &backoff.ExponentialBackOff{
|
|
||||||
InitialInterval: externalMonitorInitInterval,
|
|
||||||
RandomizationFactor: externalMonitorRandomization,
|
|
||||||
Multiplier: backoff.DefaultMultiplier,
|
|
||||||
MaxInterval: externalMonitorMaxInterval,
|
|
||||||
MaxElapsedTime: 0,
|
|
||||||
Clock: backoff.SystemClock,
|
|
||||||
}
|
|
||||||
bo.Reset()
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
err := m.watch(ctx)
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
delay := bo.NextBackOff()
|
|
||||||
log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay)
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(delay):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *externalChainMonitor) watch(ctx context.Context) error {
|
|
||||||
events, closeMon, err := m.subscribe()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer closeMon()
|
|
||||||
|
|
||||||
debounce := time.NewTimer(time.Hour)
|
|
||||||
if !debounce.Stop() {
|
|
||||||
<-debounce.C
|
|
||||||
}
|
|
||||||
defer debounce.Stop()
|
|
||||||
|
|
||||||
pending := false
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
case <-debounce.C:
|
|
||||||
pending = false
|
|
||||||
m.reconcile()
|
|
||||||
case ev, ok := <-events:
|
|
||||||
if !ok {
|
|
||||||
return errors.New("monitor channel closed")
|
|
||||||
}
|
|
||||||
if ev.Error != nil {
|
|
||||||
return fmt.Errorf("monitor event: %w", ev.Error)
|
|
||||||
}
|
|
||||||
if !isRelevantMonitorEvent(ev) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
resetDebounce(debounce, pending)
|
|
||||||
pending = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) {
|
|
||||||
conn := &nftables.Conn{}
|
|
||||||
mon := nftables.NewMonitor(
|
|
||||||
nftables.WithMonitorAction(nftables.MonitorActionNew),
|
|
||||||
nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables),
|
|
||||||
)
|
|
||||||
events, err := conn.AddMonitor(mon)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("add netlink monitor: %w", err)
|
|
||||||
}
|
|
||||||
return events, func() { _ = mon.Close() }, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// resetDebounce reschedules a pending debounce timer without leaking a stale
|
|
||||||
// fire on its channel. pending must reflect whether the timer is armed.
|
|
||||||
func resetDebounce(t *time.Timer, pending bool) {
|
|
||||||
if pending && !t.Stop() {
|
|
||||||
select {
|
|
||||||
case <-t.C:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.Reset(externalMonitorReconcileDelay)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *externalChainMonitor) reconcile() {
|
|
||||||
if err := m.reconciler.reconcileExternalChains(); err != nil {
|
|
||||||
log.Warnf("reconcile external chain rules: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isRelevantMonitorEvent returns true for table/chain creation events on
|
|
||||||
// families we care about. The reconciler filters to actual external filter
|
|
||||||
// chains.
|
|
||||||
func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool {
|
|
||||||
switch ev.Type {
|
|
||||||
case nftables.MonitorEventTypeNewChain:
|
|
||||||
chain, ok := ev.Data.(*nftables.Chain)
|
|
||||||
if !ok || chain == nil || chain.Table == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return isMonitoredFamily(chain.Table.Family)
|
|
||||||
case nftables.MonitorEventTypeNewTable:
|
|
||||||
table, ok := ev.Data.(*nftables.Table)
|
|
||||||
if !ok || table == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return isMonitoredFamily(table.Family)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isMonitoredFamily(family nftables.TableFamily) bool {
|
|
||||||
switch family {
|
|
||||||
case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIsMonitoredFamily(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
family nftables.TableFamily
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{nftables.TableFamilyIPv4, true},
|
|
||||||
{nftables.TableFamilyIPv6, true},
|
|
||||||
{nftables.TableFamilyINet, true},
|
|
||||||
{nftables.TableFamilyARP, false},
|
|
||||||
{nftables.TableFamilyBridge, false},
|
|
||||||
{nftables.TableFamilyNetdev, false},
|
|
||||||
{nftables.TableFamilyUnspecified, false},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsRelevantMonitorEvent(t *testing.T) {
|
|
||||||
inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet}
|
|
||||||
ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
|
|
||||||
arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
ev *nftables.MonitorEvent
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "new chain in inet firewalld",
|
|
||||||
ev: &nftables.MonitorEvent{
|
|
||||||
Type: nftables.MonitorEventTypeNewChain,
|
|
||||||
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "new chain in ip filter",
|
|
||||||
ev: &nftables.MonitorEvent{
|
|
||||||
Type: nftables.MonitorEventTypeNewChain,
|
|
||||||
Data: &nftables.Chain{Name: "INPUT", Table: ipTable},
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "new chain in unwatched arp family",
|
|
||||||
ev: &nftables.MonitorEvent{
|
|
||||||
Type: nftables.MonitorEventTypeNewChain,
|
|
||||||
Data: &nftables.Chain{Name: "x", Table: arpTable},
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "new table inet",
|
|
||||||
ev: &nftables.MonitorEvent{
|
|
||||||
Type: nftables.MonitorEventTypeNewTable,
|
|
||||||
Data: inetTable,
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "del chain (we only act on new)",
|
|
||||||
ev: &nftables.MonitorEvent{
|
|
||||||
Type: nftables.MonitorEventTypeDelChain,
|
|
||||||
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "chain with nil table",
|
|
||||||
ev: &nftables.MonitorEvent{
|
|
||||||
Type: nftables.MonitorEventTypeNewChain,
|
|
||||||
Data: &nftables.Chain{Name: "x"},
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil data",
|
|
||||||
ev: &nftables.MonitorEvent{
|
|
||||||
Type: nftables.MonitorEventTypeNewChain,
|
|
||||||
Data: (*nftables.Chain)(nil),
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fakeReconciler records reconcile invocations for debounce tests.
|
|
||||||
type fakeReconciler struct {
|
|
||||||
calls chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeReconciler) reconcileExternalChains() error {
|
|
||||||
f.calls <- struct{}{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExternalChainMonitorStopWithoutStart(t *testing.T) {
|
|
||||||
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
|
|
||||||
// Must not panic or block.
|
|
||||||
m.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExternalChainMonitorDoubleStart(t *testing.T) {
|
|
||||||
// start() twice should be a no-op; stop() cleans up once.
|
|
||||||
// We avoid exercising the netlink watch loop here because it needs root.
|
|
||||||
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
|
|
||||||
|
|
||||||
// Replace run with a stub that just waits for cancel, so start() stays
|
|
||||||
// deterministic without opening a netlink socket.
|
|
||||||
origDone := make(chan struct{})
|
|
||||||
m.done = origDone
|
|
||||||
m.cancel = func() { close(origDone) }
|
|
||||||
|
|
||||||
// Second start should be a no-op (cancel already set).
|
|
||||||
m.start()
|
|
||||||
assert.NotNil(t, m.cancel)
|
|
||||||
|
|
||||||
m.stop()
|
|
||||||
assert.Nil(t, m.cancel)
|
|
||||||
assert.Nil(t, m.done)
|
|
||||||
}
|
|
||||||
@@ -11,11 +11,9 @@ import (
|
|||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
@@ -51,17 +49,10 @@ type Manager struct {
|
|||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
router *router
|
router *router
|
||||||
aclManager *AclManager
|
aclManager *AclManager
|
||||||
|
|
||||||
// IPv6 counterparts, nil when no v6 overlay
|
|
||||||
router6 *router
|
|
||||||
aclManager6 *AclManager
|
|
||||||
|
|
||||||
notrackOutputChain *nftables.Chain
|
notrackOutputChain *nftables.Chain
|
||||||
notrackPreroutingChain *nftables.Chain
|
notrackPreroutingChain *nftables.Chain
|
||||||
|
|
||||||
extMonitor *externalChainMonitor
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
@@ -71,8 +62,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
tableName := getTableName()
|
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
||||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||||
@@ -85,137 +75,35 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if wgIface.Address().HasIPv6() {
|
|
||||||
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
|
|
||||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.extMonitor = newExternalChainMonitor(m)
|
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
|
|
||||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create v6 router: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Share the same IP forwarding state with the v4 router, since
|
|
||||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
|
||||||
m.router6.ipFwdState = m.router.ipFwdState
|
|
||||||
|
|
||||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
|
||||||
func (m *Manager) hasIPv6() bool {
|
|
||||||
return m.router6 != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) initIPv6() error {
|
|
||||||
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create v6 work table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.router6.init(workTable6); err != nil {
|
|
||||||
return fmt.Errorf("v6 router init: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.aclManager6.init(workTable6); err != nil {
|
|
||||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init nftables firewall manager
|
// Init nftables firewall manager
|
||||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
if err := m.initFirewall(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.persistState(stateManager)
|
|
||||||
|
|
||||||
// Start after initFirewall has installed the baseline external-chain
|
|
||||||
// accept rules. start() is idempotent across Init/Close/Init cycles.
|
|
||||||
m.extMonitor.start()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconcileExternalChains re-applies passthrough accept rules to external
|
|
||||||
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
|
|
||||||
// tables or chains appear (e.g. after firewalld reloads).
|
|
||||||
func (m *Manager) reconcileExternalChains() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
if m.router != nil {
|
|
||||||
if err := m.router.acceptExternalChainsRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.router6.acceptExternalChainsRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) initFirewall() (err error) {
|
|
||||||
workTable, err := m.createWorkTable()
|
workTable, err := m.createWorkTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create work table: %w", err)
|
return fmt.Errorf("create work table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
m.rollbackInit()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := m.router.init(workTable); err != nil {
|
if err := m.router.init(workTable); err != nil {
|
||||||
return fmt.Errorf("router init: %w", err)
|
return fmt.Errorf("router init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.aclManager.init(workTable); err != nil {
|
if err := m.aclManager.init(workTable); err != nil {
|
||||||
|
// TODO: cleanup router
|
||||||
return fmt.Errorf("acl manager init: %w", err)
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.initIPv6(); err != nil {
|
|
||||||
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
|
|
||||||
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.initNoTrackChains(workTable); err != nil {
|
if err := m.initNoTrackChains(workTable); err != nil {
|
||||||
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// persistState saves the current interface state for potential recreation on restart.
|
|
||||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
|
||||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
|
||||||
// cleanup using Close() without needing to store specific rules.
|
|
||||||
func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
|
// We only need to record minimal interface state for potential recreation.
|
||||||
|
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||||
|
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||||
|
// cleanup using Close() without needing to store specific rules.
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
@@ -226,29 +114,14 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
|||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// persist early
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
return nil
|
||||||
func (m *Manager) rollbackInit() {
|
|
||||||
if err := m.router.Reset(); err != nil {
|
|
||||||
log.Warnf("rollback router: %v", err)
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.router6.Reset(); err != nil {
|
|
||||||
log.Warnf("rollback v6 router: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := m.cleanupNetbirdTables(); err != nil {
|
|
||||||
log.Warnf("cleanup tables: %v", err)
|
|
||||||
}
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
log.Warnf("flush: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
@@ -267,14 +140,12 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if ip.To4() != nil {
|
rawIP := ip.To4()
|
||||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
if rawIP == nil {
|
||||||
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.hasIPv6() {
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
@@ -288,11 +159,8 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if isIPv6RouteRule(sources, destination) {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
if !m.hasIPv6() {
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
@@ -303,66 +171,15 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
|
||||||
return m.aclManager6.DeletePeerRule(rule)
|
|
||||||
}
|
|
||||||
return m.aclManager.DeletePeerRule(rule)
|
return m.aclManager.DeletePeerRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIPv6Rule(rule firewall.Rule) bool {
|
// DeleteRouteRule deletes a routing rule
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
|
||||||
// For static routes, the destination prefix determines the family. For dynamic
|
|
||||||
// routes (DomainSet), the sources determine the family since management
|
|
||||||
// duplicates dynamic rules per family.
|
|
||||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
|
||||||
if destination.IsPrefix() {
|
|
||||||
return destination.Prefix.Addr().Is6()
|
|
||||||
}
|
|
||||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
|
|
||||||
// router; the cached maps are normally authoritative, so the kernel is only
|
|
||||||
// consulted when neither map knows about the rule.
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
id := rule.ID()
|
return m.router.DeleteRouteRule(rule)
|
||||||
r, err := m.routerForRuleID(id, (*router).hasRule)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return r.DeleteRouteRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// routerForRuleID picks the router holding the rule with the given id, using
|
|
||||||
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
|
|
||||||
// from the kernel once and re-checks before falling back to the v4 router.
|
|
||||||
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
|
|
||||||
if has(m.router, id) {
|
|
||||||
return m.router, nil
|
|
||||||
}
|
|
||||||
if m.hasIPv6() && has(m.router6, id) {
|
|
||||||
return m.router6, nil
|
|
||||||
}
|
|
||||||
if !m.hasIPv6() {
|
|
||||||
return m.router, nil
|
|
||||||
}
|
|
||||||
if err := m.router.refreshRulesMap(); err != nil {
|
|
||||||
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
|
||||||
}
|
|
||||||
if err := m.router6.refreshRulesMap(); err != nil {
|
|
||||||
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
|
||||||
}
|
|
||||||
if has(m.router6, id) && !has(m.router, id) {
|
|
||||||
return m.router6, nil
|
|
||||||
}
|
|
||||||
return m.router, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
@@ -377,70 +194,19 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
return m.router.AddNatRule(pair)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddNatRule(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.router.AddNatRule(pair); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dynamic routes need NAT in both tables since resolved IPs can be
|
|
||||||
// either v4 or v6. This covers both DomainSet (modern) and the legacy
|
|
||||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
|
||||||
// On v6 failure we keep the v4 NAT rule rather than rolling back: half
|
|
||||||
// connectivity is better than none, and RemoveNatRule is content-keyed
|
|
||||||
// so the eventual cleanup still works.
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
|
||||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
|
||||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
return m.router.RemoveNatRule(pair)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.router6.RemoveNatRule(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
|
||||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic.
|
// AllowNetbird allows netbird interface traffic.
|
||||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||||
//
|
|
||||||
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
|
|
||||||
// which doesn't override DROP rules in external tables (e.g. firewalld).
|
|
||||||
// Should add passthrough rules to external chains (like the native mode router's
|
|
||||||
// addExternalChainsRules does) for both the netbird table family and inet tables.
|
|
||||||
// The netbird table itself is fine (routing chains already exist there), but
|
|
||||||
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -448,11 +214,6 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||||
return fmt.Errorf("create default allow rules: %w", err)
|
return fmt.Errorf("create default allow rules: %w", err)
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
|
||||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
}
|
}
|
||||||
@@ -466,47 +227,31 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
return err
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
|
||||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the firewall manager
|
// Close closes the firewall manager
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.extMonitor.stop()
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := m.router.Reset(); err != nil {
|
if err := m.router.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
return fmt.Errorf("reset router: %v", err)
|
||||||
}
|
|
||||||
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.router6.Reset(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.cleanupNetbirdTables(); err != nil {
|
if err := m.cleanupNetbirdTables(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
|
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
|
return fmt.Errorf("delete state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) cleanupNetbirdTables() error {
|
func (m *Manager) cleanupNetbirdTables() error {
|
||||||
@@ -555,12 +300,6 @@ func (m *Manager) Flush() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.aclManager6.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush v6 acl: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshNoTrackChains(); err != nil {
|
if err := m.refreshNoTrackChains(); err != nil {
|
||||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||||
}
|
}
|
||||||
@@ -573,12 +312,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if rule.TranslatedAddress.Is6() {
|
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddDNATRule(rule)
|
|
||||||
}
|
|
||||||
return m.router.AddDNATRule(rule)
|
return m.router.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -587,11 +320,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return r.DeleteDNATRule(rule)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSet updates the set with the given prefixes
|
// UpdateSet updates the set with the given prefixes
|
||||||
@@ -599,82 +328,39 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
return m.router.UpdateSet(set, prefixes)
|
||||||
for _, p := range prefixes {
|
|
||||||
if p.Addr().Is6() {
|
|
||||||
v6Prefixes = append(v6Prefixes, p)
|
|
||||||
} else {
|
|
||||||
v4Prefixes = append(v4Prefixes, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
|
||||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
|
||||||
return fmt.Errorf("update v6 set: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if localAddr.Is6() {
|
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
|
||||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -848,11 +534,7 @@ func (m *Manager) refreshNoTrackChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
return m.createWorkTableFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
|
|
||||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
@@ -864,7 +546,7 @@ func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
|
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
return table, err
|
return table, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -383,138 +383,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
err = manager.AddNatRule(pair)
|
err = manager.AddNatRule(pair)
|
||||||
require.NoError(t, err, "failed to add NAT rule")
|
require.NoError(t, err, "failed to add NAT rule")
|
||||||
|
|
||||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
|
||||||
Protocol: fw.ProtocolTCP,
|
|
||||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
|
||||||
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
|
|
||||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
|
||||||
})
|
|
||||||
require.NoError(t, err, "failed to add DNAT rule")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
|
|
||||||
})
|
|
||||||
|
|
||||||
stdout, stderr = runIptablesSave(t)
|
stdout, stderr = runIptablesSave(t)
|
||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
|
|
||||||
if _, err := exec.LookPath(bin); err != nil {
|
|
||||||
t.Skipf("%s not available on this system: %v", bin, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Seed ip6 tables in the nft backend. Docker may not create them.
|
|
||||||
seedIp6tables(t)
|
|
||||||
|
|
||||||
ifaceMockV6 := &iFaceMock{
|
|
||||||
NameFunc: func() string { return "wt-test" },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.96.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
|
||||||
IPv6: netip.MustParseAddr("fd00::1"),
|
|
||||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err, "create manager")
|
|
||||||
require.NoError(t, manager.Init(nil))
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.Close(nil), "close manager")
|
|
||||||
|
|
||||||
stdout, stderr := runIp6tablesSave(t)
|
|
||||||
verifyIp6tablesOutput(t, stdout, stderr)
|
|
||||||
})
|
|
||||||
|
|
||||||
ip := netip.MustParseAddr("fd00::2")
|
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
|
||||||
require.NoError(t, err, "add v6 peer filtering rule")
|
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
|
||||||
nil,
|
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []uint16{443}},
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "add v6 route filtering rule")
|
|
||||||
|
|
||||||
err = manager.AddNatRule(fw.RouterPair{
|
|
||||||
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
|
|
||||||
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
|
||||||
Masquerade: true,
|
|
||||||
})
|
|
||||||
require.NoError(t, err, "add v6 NAT rule")
|
|
||||||
|
|
||||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
|
||||||
Protocol: fw.ProtocolTCP,
|
|
||||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
|
||||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
|
||||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
|
||||||
})
|
|
||||||
require.NoError(t, err, "add v6 DNAT rule")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
|
|
||||||
})
|
|
||||||
|
|
||||||
stdout, stderr := runIptablesSave(t)
|
|
||||||
verifyIptablesOutput(t, stdout, stderr)
|
|
||||||
|
|
||||||
stdout, stderr = runIp6tablesSave(t)
|
|
||||||
verifyIp6tablesOutput(t, stdout, stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func seedIp6tables(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
for _, tc := range []struct{ table, chain string }{
|
|
||||||
{"filter", "FORWARD"},
|
|
||||||
{"nat", "POSTROUTING"},
|
|
||||||
{"mangle", "FORWARD"},
|
|
||||||
} {
|
|
||||||
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
|
|
||||||
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
|
|
||||||
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
|
|
||||||
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func runIp6tablesSave(t *testing.T) (string, string) {
|
|
||||||
t.Helper()
|
|
||||||
var stdout, stderr bytes.Buffer
|
|
||||||
cmd := exec.Command("ip6tables-save")
|
|
||||||
cmd.Stdout = &stdout
|
|
||||||
cmd.Stderr = &stderr
|
|
||||||
require.NoError(t, cmd.Run(), "ip6tables-save failed")
|
|
||||||
return stdout.String(), stderr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
|
|
||||||
t.Helper()
|
|
||||||
for _, msg := range []string{
|
|
||||||
"Table `nat' is incompatible",
|
|
||||||
"Table `mangle' is incompatible",
|
|
||||||
"Table `filter' is incompatible",
|
|
||||||
} {
|
|
||||||
require.NotContains(t, stdout, msg,
|
|
||||||
"ip6tables-save stdout reports incompatibility: %s", stdout)
|
|
||||||
require.NotContains(t, stderr, msg,
|
|
||||||
"ip6tables-save stderr reports incompatibility: %s", stderr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
|
|||||||
@@ -50,10 +50,8 @@ const (
|
|||||||
dnatSuffix = "_dnat"
|
dnatSuffix = "_dnat"
|
||||||
snatSuffix = "_snat"
|
snatSuffix = "_snat"
|
||||||
|
|
||||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||||
ipv4TCPHeaderSize = 40
|
ipTCPHeaderMinSize = 40
|
||||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
|
||||||
ipv6TCPHeaderSize = 60
|
|
||||||
|
|
||||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||||
maxPrefixesSet = 1500
|
maxPrefixesSet = 1500
|
||||||
@@ -78,7 +76,6 @@ type router struct {
|
|||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||||
|
|
||||||
af addrFamily
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
@@ -91,7 +88,6 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
workTable: workTable,
|
workTable: workTable,
|
||||||
chains: make(map[string]*nftables.Chain),
|
chains: make(map[string]*nftables.Chain),
|
||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
@@ -154,7 +150,7 @@ func (r *router) Reset() error {
|
|||||||
func (r *router) removeNatPreroutingRules() error {
|
func (r *router) removeNatPreroutingRules() error {
|
||||||
table := &nftables.Table{
|
table := &nftables.Table{
|
||||||
Name: tableNat,
|
Name: tableNat,
|
||||||
Family: r.af.tableFamily,
|
Family: nftables.TableFamilyIPv4,
|
||||||
}
|
}
|
||||||
chain := &nftables.Chain{
|
chain := &nftables.Chain{
|
||||||
Name: chainNameNatPrerouting,
|
Name: chainNameNatPrerouting,
|
||||||
@@ -187,7 +183,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list tables: %w", err)
|
return nil, fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
@@ -423,7 +419,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
|
|
||||||
// Handle protocol
|
// Handle protocol
|
||||||
if proto != firewall.ProtocolALL {
|
if proto != firewall.ProtocolALL {
|
||||||
protoNum, err := r.af.protoNum(proto)
|
protoNum, err := protoToInt(proto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -483,24 +479,7 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
|
|||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.getIpSetExprs(ref, isSource)
|
return getIpSetExprs(ref, isSource)
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) iptablesProto() iptables.Protocol {
|
|
||||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
|
||||||
return iptables.ProtocolIPv6
|
|
||||||
}
|
|
||||||
return iptables.ProtocolIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) hasRule(id string) bool {
|
|
||||||
_, ok := r.rules[id]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) hasDNATRule(id string) bool {
|
|
||||||
_, ok := r.rules[id+dnatSuffix]
|
|
||||||
return ok
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
@@ -549,10 +528,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
|||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
// required for prefixes
|
// required for prefixes
|
||||||
Interval: true,
|
Interval: true,
|
||||||
KeyType: r.af.setKeyType,
|
KeyType: nftables.TypeIPAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
elements := r.convertPrefixesToSet(prefixes)
|
elements := convertPrefixesToSet(prefixes)
|
||||||
nElements := len(elements)
|
nElements := len(elements)
|
||||||
|
|
||||||
maxElements := maxPrefixesSet * 2
|
maxElements := maxPrefixesSet * 2
|
||||||
@@ -585,17 +564,23 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
|||||||
return nfset, nil
|
return nfset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||||
var elements []nftables.SetElement
|
var elements []nftables.SetElement
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
|
// TODO: Implement IPv6 support
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||||
firstIP := prefix.Addr()
|
firstIP := prefix.Addr()
|
||||||
lastIP := calculateLastIP(prefix).Next()
|
lastIP := calculateLastIP(prefix).Next()
|
||||||
|
|
||||||
elements = append(elements,
|
elements = append(elements,
|
||||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||||
)
|
)
|
||||||
@@ -605,20 +590,10 @@ func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetEle
|
|||||||
|
|
||||||
// calculateLastIP determines the last IP in a given prefix.
|
// calculateLastIP determines the last IP in a given prefix.
|
||||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||||
masked := prefix.Masked()
|
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||||
if masked.Addr().Is4() {
|
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||||
hostMask := ^uint32(0) >> masked.Bits()
|
|
||||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
|
||||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6: set host bits to all 1s
|
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||||
b := masked.Addr().As16()
|
|
||||||
bits := masked.Bits()
|
|
||||||
for i := bits; i < 128; i++ {
|
|
||||||
b[i/8] |= 1 << (7 - i%8)
|
|
||||||
}
|
|
||||||
return netip.AddrFrom16(b)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Utility function to convert netip.Addr to uint32.
|
// Utility function to convert netip.Addr to uint32.
|
||||||
@@ -870,16 +845,9 @@ func (r *router) addPostroutingRules() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
|
// TODO: Add IPv6 support
|
||||||
func (r *router) addMSSClampingRules() error {
|
func (r *router) addMSSClampingRules() error {
|
||||||
overhead := uint16(ipv4TCPHeaderSize)
|
mss := r.mtu - ipTCPHeaderMinSize
|
||||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
|
||||||
overhead = ipv6TCPHeaderSize
|
|
||||||
}
|
|
||||||
if r.mtu <= overhead {
|
|
||||||
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
mss := r.mtu - overhead
|
|
||||||
|
|
||||||
exprsOut := []expr.Any{
|
exprsOut := []expr.Any{
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
@@ -1086,22 +1054,17 @@ func (r *router) acceptFilterTableRules() error {
|
|||||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Try iptables first and fallback to nftables if iptables is not available.
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
ipt, err := iptables.New()
|
||||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// iptables is not available but the filter table exists
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
fw = "nftables"
|
fw = "nftables"
|
||||||
return r.acceptFilterRulesNftables(r.filterTable)
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
return r.acceptFilterRulesIptables(ipt)
|
||||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
|
||||||
fw = "nftables"
|
|
||||||
return r.acceptFilterRulesNftables(r.filterTable)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
@@ -1172,122 +1135,83 @@ func (r *router) acceptExternalChainsRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
r.applyExternalChainAccept(chain, intf)
|
if chain.Hooknum == nil {
|
||||||
|
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||||
|
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||||
|
|
||||||
|
switch *chain.Hooknum {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
r.insertForwardAcceptRules(chain, intf)
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
r.insertInputAcceptRule(chain, intf)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush external chain rules: %w", err)
|
return fmt.Errorf("flush external chain rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
|
|
||||||
if chain.Hooknum == nil {
|
|
||||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
|
||||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
|
||||||
|
|
||||||
switch *chain.Hooknum {
|
|
||||||
case *nftables.ChainHookForward:
|
|
||||||
r.insertForwardAcceptRules(chain, intf)
|
|
||||||
case *nftables.ChainHookInput:
|
|
||||||
r.insertInputAcceptRule(chain, intf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
iifRule := &nftables.Rule{
|
||||||
if err != nil {
|
|
||||||
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.insertForwardIifRule(chain, intf, existing)
|
|
||||||
r.insertForwardOifEstablishedRule(chain, intf, existing)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
|
||||||
if existing[userDataAcceptForwardRuleIif] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
Table: chain.Table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
&expr.Counter{},
|
&expr.Counter{},
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||||
},
|
},
|
||||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||||
})
|
}
|
||||||
}
|
r.conn.InsertRule(iifRule)
|
||||||
|
|
||||||
func (r *router) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
oifExprs := []expr.Any{
|
||||||
if existing[userDataAcceptForwardRuleOif] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(&nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: chain.Table,
|
Table: chain.Table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: append(exprs, getEstablishedExprs(2)...),
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
})
|
}
|
||||||
|
r.conn.InsertRule(oifRule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
inputRule := &nftables.Rule{
|
||||||
if err != nil {
|
|
||||||
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if existing[userDataAcceptInputRule] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
Table: chain.Table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
&expr.Counter{},
|
&expr.Counter{},
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||||
},
|
},
|
||||||
UserData: []byte(userDataAcceptInputRule),
|
UserData: []byte(userDataAcceptInputRule),
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
|
|
||||||
func (r *router) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
|
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list rules: %w", err)
|
|
||||||
}
|
}
|
||||||
present := map[string]bool{}
|
r.conn.InsertRule(inputRule)
|
||||||
for _, rule := range rules {
|
|
||||||
if !isNetbirdAcceptRuleTag(rule.UserData) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
present[string(rule.UserData)] = true
|
|
||||||
}
|
|
||||||
return present, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isNetbirdAcceptRuleTag(userData []byte) bool {
|
|
||||||
switch string(userData) {
|
|
||||||
case userDataAcceptForwardRuleIif,
|
|
||||||
userDataAcceptForwardRuleOif,
|
|
||||||
userDataAcceptInputRule:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRules() error {
|
func (r *router) removeAcceptFilterRules() error {
|
||||||
@@ -1309,17 +1233,13 @@ func (r *router) removeFilterTableRules() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
return r.removeAcceptFilterRulesIptables(ipt)
|
||||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
|
||||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
@@ -1386,7 +1306,7 @@ func (r *router) removeExternalChainsRules() error {
|
|||||||
func (r *router) findExternalChains() []*nftables.Chain {
|
func (r *router) findExternalChains() []*nftables.Chain {
|
||||||
var chains []*nftables.Chain
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||||
|
|
||||||
for _, family := range families {
|
for _, family := range families {
|
||||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
@@ -1417,8 +1337,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
// Skip all iptables-managed tables in the ip family
|
||||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1559,7 +1479,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
protoNum, err := protoToInt(rule.Protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1622,7 +1542,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
|
|||||||
dnatExprs = append(dnatExprs,
|
dnatExprs = append(dnatExprs,
|
||||||
&expr.NAT{
|
&expr.NAT{
|
||||||
Type: expr.NATTypeDestNAT,
|
Type: expr.NATTypeDestNAT,
|
||||||
Family: uint32(r.af.tableFamily),
|
Family: uint32(nftables.TableFamilyIPv4),
|
||||||
RegAddrMin: 1,
|
RegAddrMin: 1,
|
||||||
RegProtoMin: regProtoMin,
|
RegProtoMin: regProtoMin,
|
||||||
RegProtoMax: regProtoMax,
|
RegProtoMax: regProtoMax,
|
||||||
@@ -1715,15 +1635,14 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
natTable := &nftables.Table{
|
|
||||||
Name: tableNat,
|
|
||||||
Family: r.af.tableFamily,
|
|
||||||
}
|
|
||||||
dnatRule := &nftables.Rule{
|
dnatRule := &nftables.Rule{
|
||||||
Table: natTable,
|
Table: &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
},
|
||||||
Chain: &nftables.Chain{
|
Chain: &nftables.Chain{
|
||||||
Name: chainNameNatPrerouting,
|
Name: chainNameNatPrerouting,
|
||||||
Table: natTable,
|
Table: r.filterTable,
|
||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityNATDest,
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
@@ -1754,8 +1673,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
|
|||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: r.af.dstAddrOffset,
|
Offset: 16,
|
||||||
Len: r.af.addrLen,
|
Len: 4,
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
@@ -1833,7 +1752,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
elements := r.convertPrefixesToSet(prefixes)
|
elements := convertPrefixesToSet(prefixes)
|
||||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||||
}
|
}
|
||||||
@@ -1848,14 +1767,14 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
protoNum, err := r.af.protoNum(protocol)
|
protoNum, err := protoToInt(protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("convert protocol to number: %w", err)
|
return fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1882,15 +1801,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 3,
|
Register: 3,
|
||||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
bits := 32
|
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||||
if localAddr.Is6() {
|
|
||||||
bits = 128
|
|
||||||
}
|
|
||||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
|
||||||
|
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
@@ -1899,11 +1814,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
},
|
},
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
Register: 2,
|
Register: 2,
|
||||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||||
},
|
},
|
||||||
&expr.NAT{
|
&expr.NAT{
|
||||||
Type: expr.NATTypeDestNAT,
|
Type: expr.NATTypeDestNAT,
|
||||||
Family: uint32(r.af.tableFamily),
|
Family: uint32(nftables.TableFamilyIPv4),
|
||||||
RegAddrMin: 1,
|
RegAddrMin: 1,
|
||||||
RegProtoMin: 2,
|
RegProtoMin: 2,
|
||||||
RegProtoMax: 0,
|
RegProtoMax: 0,
|
||||||
@@ -1928,12 +1843,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
rule, exists := r.rules[ruleID]
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -1979,8 +1894,8 @@ func (r *router) ensureNATOutputChain() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
return nil
|
return nil
|
||||||
@@ -1990,7 +1905,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
protoNum, err := r.af.protoNum(protocol)
|
protoNum, err := protoToInt(protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("convert protocol to number: %w", err)
|
return fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2011,15 +1926,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 2,
|
Register: 2,
|
||||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
bits := 32
|
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||||
if localAddr.Is6() {
|
|
||||||
bits = 128
|
|
||||||
}
|
|
||||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
|
||||||
|
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
@@ -2028,11 +1939,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
},
|
},
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
Register: 2,
|
Register: 2,
|
||||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||||
},
|
},
|
||||||
&expr.NAT{
|
&expr.NAT{
|
||||||
Type: expr.NATTypeDestNAT,
|
Type: expr.NATTypeDestNAT,
|
||||||
Family: uint32(r.af.tableFamily),
|
Family: uint32(nftables.TableFamilyIPv4),
|
||||||
RegAddrMin: 1,
|
RegAddrMin: 1,
|
||||||
RegProtoMin: 2,
|
RegProtoMin: 2,
|
||||||
},
|
},
|
||||||
@@ -2056,12 +1967,12 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
rule, exists := r.rules[ruleID]
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -2100,44 +2011,45 @@ func (r *router) applyNetwork(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if network.IsPrefix() {
|
if network.IsPrefix() {
|
||||||
return r.applyPrefix(network.Prefix, isSource), nil
|
return applyPrefix(network.Prefix, isSource), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||||
func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||||
// dst offset by default
|
// dst offset
|
||||||
offset := r.af.dstAddrOffset
|
offset := uint32(16)
|
||||||
if isSource {
|
if isSource {
|
||||||
// src offset
|
// src offset
|
||||||
offset = r.af.srcAddrOffset
|
offset = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
ones := prefix.Bits()
|
ones := prefix.Bits()
|
||||||
// unspecified address (/0) doesn't need extra expressions
|
// 0.0.0.0/0 doesn't need extra expressions
|
||||||
if ones == 0 {
|
if ones == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
mask := net.CIDRMask(ones, r.af.totalBits)
|
mask := net.CIDRMask(ones, 32)
|
||||||
xor := make([]byte, r.af.addrLen)
|
|
||||||
|
|
||||||
return []expr.Any{
|
return []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
Len: r.af.addrLen,
|
Len: 4,
|
||||||
},
|
},
|
||||||
|
// netmask
|
||||||
&expr.Bitwise{
|
&expr.Bitwise{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
SourceRegister: 1,
|
SourceRegister: 1,
|
||||||
Len: r.af.addrLen,
|
Len: 4,
|
||||||
Mask: mask,
|
Mask: mask,
|
||||||
Xor: xor,
|
Xor: []byte{0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
|
// net address
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
@@ -2220,12 +2132,13 @@ func getCtNewExprs() []expr.Any {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||||
// dst offset by default
|
|
||||||
offset := r.af.dstAddrOffset
|
// dst offset
|
||||||
|
offset := uint32(16)
|
||||||
if isSource {
|
if isSource {
|
||||||
// src offset
|
// src offset
|
||||||
offset = r.af.srcAddrOffset
|
offset = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
return []expr.Any{
|
return []expr.Any{
|
||||||
@@ -2233,7 +2146,7 @@ func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool)
|
|||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
Len: r.af.addrLen,
|
Len: 4,
|
||||||
},
|
},
|
||||||
&expr.Lookup{
|
&expr.Lookup{
|
||||||
SourceRegister: 1,
|
SourceRegister: 1,
|
||||||
|
|||||||
@@ -90,9 +90,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build CIDR matching expressions
|
// Build CIDR matching expressions
|
||||||
testRouter := &router{af: afIPv4}
|
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
|
||||||
|
|
||||||
// Combine all expressions in the correct order
|
// Combine all expressions in the correct order
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
@@ -509,136 +508,6 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
workTable, err := createWorkTableIPv6()
|
|
||||||
require.NoError(t, err, "Failed to create v6 work table")
|
|
||||||
defer deleteWorkTableIPv6()
|
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err, "Failed to create router")
|
|
||||||
require.NoError(t, r.init(workTable))
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
sources []netip.Prefix
|
|
||||||
expected []netip.Prefix
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Single IPv6",
|
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Multiple IPv6 Subnets",
|
|
||||||
sources: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("fd00::/64"),
|
|
||||||
netip.MustParsePrefix("2001:db8::/48"),
|
|
||||||
netip.MustParsePrefix("fe80::/10"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Overlapping IPv6",
|
|
||||||
sources: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("fd00::/48"),
|
|
||||||
netip.MustParsePrefix("fd00::/64"),
|
|
||||||
netip.MustParsePrefix("fd00::1/128"),
|
|
||||||
},
|
|
||||||
expected: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("fd00::/48"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Mixed prefix lengths",
|
|
||||||
sources: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
|
||||||
netip.MustParsePrefix("2001:db8:2::1/128"),
|
|
||||||
netip.MustParsePrefix("fd00:abcd::/32"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
|
||||||
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
|
||||||
require.NoError(t, err, "Failed to create IPv6 set")
|
|
||||||
require.NotNil(t, set)
|
|
||||||
|
|
||||||
assert.Equal(t, setName, set.Name)
|
|
||||||
assert.True(t, set.Interval)
|
|
||||||
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
|
|
||||||
|
|
||||||
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
|
||||||
require.NoError(t, err, "Failed to fetch created set")
|
|
||||||
|
|
||||||
elements, err := r.conn.GetSetElements(fetchedSet)
|
|
||||||
require.NoError(t, err, "Failed to get set elements")
|
|
||||||
|
|
||||||
uniquePrefixes := make(map[string]bool)
|
|
||||||
for _, elem := range elements {
|
|
||||||
if !elem.IntervalEnd && len(elem.Key) == 16 {
|
|
||||||
ip := netip.AddrFrom16([16]byte(elem.Key))
|
|
||||||
uniquePrefixes[ip.String()] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedCount := len(tt.expected)
|
|
||||||
if expectedCount == 0 {
|
|
||||||
expectedCount = len(tt.sources)
|
|
||||||
}
|
|
||||||
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
|
|
||||||
|
|
||||||
r.conn.DelSet(set)
|
|
||||||
require.NoError(t, r.conn.Flush())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createWorkTableIPv6() (*nftables.Table, error) {
|
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == tableNameNetbird {
|
|
||||||
sConn.DelTable(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
|
|
||||||
err = sConn.Flush()
|
|
||||||
return table, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteWorkTableIPv6() {
|
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == tableNameNetbird {
|
|
||||||
sConn.DelTable(t)
|
|
||||||
_ = sConn.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -758,7 +627,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
|||||||
|
|
||||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||||
var metaFound, cmpFound bool
|
var metaFound, cmpFound bool
|
||||||
expectedProto, _ := afIPv4.protoNum(proto)
|
expectedProto, _ := protoToInt(proto)
|
||||||
for _, e := range exprs {
|
for _, e := range exprs {
|
||||||
switch ex := e.(type) {
|
switch ex := e.(type) {
|
||||||
case *expr.Meta:
|
case *expr.Meta:
|
||||||
@@ -985,55 +854,3 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCalculateLastIP(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
prefix string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"10.0.0.0/24", "10.0.0.255"},
|
|
||||||
{"10.0.0.0/32", "10.0.0.0"},
|
|
||||||
{"0.0.0.0/0", "255.255.255.255"},
|
|
||||||
{"192.168.1.0/28", "192.168.1.15"},
|
|
||||||
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
|
|
||||||
{"fd00::/128", "fd00::"},
|
|
||||||
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
|
|
||||||
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.prefix, func(t *testing.T) {
|
|
||||||
prefix := netip.MustParsePrefix(tt.prefix)
|
|
||||||
got := calculateLastIP(prefix)
|
|
||||||
assert.Equal(t, tt.want, got.String())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
|
||||||
r := &router{af: afIPv6}
|
|
||||||
prefixes := []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("fd00::/64"),
|
|
||||||
netip.MustParsePrefix("2001:db8::1/128"),
|
|
||||||
}
|
|
||||||
|
|
||||||
elements := r.convertPrefixesToSet(prefixes)
|
|
||||||
|
|
||||||
// Each prefix produces 2 elements (start + end)
|
|
||||||
require.Len(t, elements, 4)
|
|
||||||
|
|
||||||
// fd00::/64 start
|
|
||||||
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
|
|
||||||
assert.False(t, elements[0].IntervalEnd)
|
|
||||||
|
|
||||||
// fd00::/64 end (fd00:0:0:1::, one past the last)
|
|
||||||
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
|
|
||||||
assert.True(t, elements[1].IntervalEnd)
|
|
||||||
|
|
||||||
// 2001:db8::1/128 start
|
|
||||||
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
|
|
||||||
assert.False(t, elements[2].IntervalEnd)
|
|
||||||
|
|
||||||
// 2001:db8::1/128 end (2001:db8::2)
|
|
||||||
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
|
|
||||||
assert.True(t, elements[3].IntervalEnd)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -5,10 +5,8 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,20 +29,15 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
if !isFirewallRuleActive(firewallRuleName) {
|
||||||
if isFirewallRuleActive(firewallRuleName) {
|
return nil
|
||||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||||
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
@@ -53,33 +46,17 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isFirewallRuleActive(firewallRuleName) {
|
if isFirewallRuleActive(firewallRuleName) {
|
||||||
if err := manageFirewallRule(firewallRuleName,
|
return nil
|
||||||
addRule,
|
|
||||||
"dir=in",
|
|
||||||
"enable=yes",
|
|
||||||
"action=allow",
|
|
||||||
"profile=any",
|
|
||||||
"localip="+m.wgIface.Address().IP.String(),
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return manageFirewallRule(firewallRuleName,
|
||||||
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
addRule,
|
||||||
if err := manageFirewallRule(firewallRuleName+"-v6",
|
"dir=in",
|
||||||
addRule,
|
"enable=yes",
|
||||||
"dir=in",
|
"action=allow",
|
||||||
"enable=yes",
|
"profile=any",
|
||||||
"action=allow",
|
"localip="+m.wgIface.Address().IP.String(),
|
||||||
"profile=any",
|
)
|
||||||
"localip="+v6.String(),
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -65,7 +64,5 @@ type ConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c ConnKey) String() string {
|
func (c ConnKey) String() string {
|
||||||
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
|
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
" → " +
|
|
||||||
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,54 +13,6 @@ import (
|
|||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestConnKey_String(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
key ConnKey
|
|
||||||
expect string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "IPv4",
|
|
||||||
key: ConnKey{
|
|
||||||
SrcIP: netip.MustParseAddr("192.168.1.1"),
|
|
||||||
DstIP: netip.MustParseAddr("10.0.0.1"),
|
|
||||||
SrcPort: 12345,
|
|
||||||
DstPort: 80,
|
|
||||||
},
|
|
||||||
expect: "192.168.1.1:12345 → 10.0.0.1:80",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6",
|
|
||||||
key: ConnKey{
|
|
||||||
SrcIP: netip.MustParseAddr("2001:db8::1"),
|
|
||||||
DstIP: netip.MustParseAddr("2001:db8::2"),
|
|
||||||
SrcPort: 54321,
|
|
||||||
DstPort: 443,
|
|
||||||
},
|
|
||||||
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv4-mapped IPv6 unmaps",
|
|
||||||
key: ConnKey{
|
|
||||||
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
|
|
||||||
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
|
|
||||||
SrcPort: 1000,
|
|
||||||
DstPort: 2000,
|
|
||||||
},
|
|
||||||
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got := tc.key.String()
|
|
||||||
if got != tc.expect {
|
|
||||||
t.Errorf("got %q, want %q", got, tc.expect)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -22,14 +21,9 @@ const (
|
|||||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
ICMPCleanupInterval = 15 * time.Second
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
|
||||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
|
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||||
// IPv4: 20-byte header + 8-byte transport = 28 bytes.
|
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||||
// IPv6: 40-byte header + 8-byte transport = 48 bytes.
|
MaxICMPPayloadLength = 28
|
||||||
MaxICMPPayloadLength = 48
|
|
||||||
// minICMPPayloadIPv4 is the minimum embedded packet length for IPv4 ICMP errors.
|
|
||||||
minICMPPayloadIPv4 = 28
|
|
||||||
// minICMPPayloadIPv6 is the minimum embedded packet length for IPv6 ICMP errors.
|
|
||||||
minICMPPayloadIPv6 = 48
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
@@ -71,7 +65,7 @@ type ICMPInfo struct {
|
|||||||
|
|
||||||
// String implements fmt.Stringer for lazy evaluation in log messages
|
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||||
func (info ICMPInfo) String() string {
|
func (info ICMPInfo) String() string {
|
||||||
if info.isErrorMessage() && info.PayloadLen >= minICMPPayloadIPv4 {
|
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||||
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||||
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||||
}
|
}
|
||||||
@@ -80,72 +74,42 @@ func (info ICMPInfo) String() string {
|
|||||||
return info.TypeCode.String()
|
return info.TypeCode.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isErrorMessage returns true if this ICMP type carries original packet info.
|
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||||
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
|
|
||||||
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
|
|
||||||
// kept as a literal.
|
|
||||||
func (info ICMPInfo) isErrorMessage() bool {
|
func (info ICMPInfo) isErrorMessage() bool {
|
||||||
typ := info.TypeCode.Type()
|
typ := info.TypeCode.Type()
|
||||||
// ICMPv4 error types
|
return typ == 3 || // Destination Unreachable
|
||||||
if typ == layers.ICMPv4TypeDestinationUnreachable ||
|
typ == 5 || // Redirect
|
||||||
typ == layers.ICMPv4TypeRedirect ||
|
typ == 11 || // Time Exceeded
|
||||||
typ == layers.ICMPv4TypeTimeExceeded ||
|
typ == 12 // Parameter Problem
|
||||||
typ == layers.ICMPv4TypeParameterProblem {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
|
|
||||||
if typ == layers.ICMPv6TypeDestinationUnreachable ||
|
|
||||||
typ == layers.ICMPv6TypePacketTooBig ||
|
|
||||||
typ == layers.ICMPv6TypeParameterProblem {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||||
func (info ICMPInfo) parseOriginalPacket() string {
|
func (info ICMPInfo) parseOriginalPacket() string {
|
||||||
if info.PayloadLen == 0 {
|
if info.PayloadLen < MaxICMPPayloadLength {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
version := (info.PayloadData[0] >> 4) & 0xF
|
// TODO: handle IPv6
|
||||||
|
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||||
var protocol uint8
|
|
||||||
var srcIP, dstIP net.IP
|
|
||||||
var transportData []byte
|
|
||||||
|
|
||||||
switch version {
|
|
||||||
case 4:
|
|
||||||
if info.PayloadLen < minICMPPayloadIPv4 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
protocol = info.PayloadData[9]
|
|
||||||
srcIP = net.IP(info.PayloadData[12:16])
|
|
||||||
dstIP = net.IP(info.PayloadData[16:20])
|
|
||||||
transportData = info.PayloadData[20:]
|
|
||||||
case 6:
|
|
||||||
if info.PayloadLen < minICMPPayloadIPv6 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
// Next Header field in IPv6 header
|
|
||||||
protocol = info.PayloadData[6]
|
|
||||||
srcIP = net.IP(info.PayloadData[8:24])
|
|
||||||
dstIP = net.IP(info.PayloadData[24:40])
|
|
||||||
transportData = info.PayloadData[40:]
|
|
||||||
default:
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protocol := info.PayloadData[9]
|
||||||
|
srcIP := net.IP(info.PayloadData[12:16])
|
||||||
|
dstIP := net.IP(info.PayloadData[16:20])
|
||||||
|
|
||||||
|
transportData := info.PayloadData[20:]
|
||||||
|
|
||||||
switch nftypes.Protocol(protocol) {
|
switch nftypes.Protocol(protocol) {
|
||||||
case nftypes.TCP:
|
case nftypes.TCP:
|
||||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
|
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
case nftypes.UDP:
|
case nftypes.UDP:
|
||||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
|
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
case nftypes.ICMP:
|
case nftypes.ICMP:
|
||||||
icmpType := transportData[0]
|
icmpType := transportData[0]
|
||||||
@@ -283,10 +247,9 @@ func (t *ICMPTracker) track(
|
|||||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
|
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,13 +301,6 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
|
|
||||||
if ip.Is6() {
|
|
||||||
return nftypes.ICMPv6
|
|
||||||
}
|
|
||||||
return nftypes.ICMP
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *ICMPTracker) Close() {
|
func (t *ICMPTracker) Close() {
|
||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
@@ -360,7 +316,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
|
|||||||
Type: typ,
|
Type: typ,
|
||||||
RuleID: ruleID,
|
RuleID: ruleID,
|
||||||
Direction: conn.Direction,
|
Direction: conn.Direction,
|
||||||
Protocol: icmpProtocolForAddr(conn.SourceIP),
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
SourceIP: conn.SourceIP,
|
SourceIP: conn.SourceIP,
|
||||||
DestIP: conn.DestIP,
|
DestIP: conn.DestIP,
|
||||||
ICMPType: conn.ICMPType,
|
ICMPType: conn.ICMPType,
|
||||||
@@ -378,7 +334,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
|
|||||||
Type: nftypes.TypeStart,
|
Type: nftypes.TypeStart,
|
||||||
RuleID: ruleID,
|
RuleID: ruleID,
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Protocol: icmpProtocolForAddr(srcIP),
|
Protocol: nftypes.ICMP,
|
||||||
SourceIP: srcIP,
|
SourceIP: srcIP,
|
||||||
DestIP: dstIP,
|
DestIP: dstIP,
|
||||||
ICMPType: typ,
|
ICMPType: typ,
|
||||||
|
|||||||
@@ -5,42 +5,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestICMPConnKey_String(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
key ICMPConnKey
|
|
||||||
expect string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "IPv4",
|
|
||||||
key: ICMPConnKey{
|
|
||||||
SrcIP: netip.MustParseAddr("192.168.1.1"),
|
|
||||||
DstIP: netip.MustParseAddr("10.0.0.1"),
|
|
||||||
ID: 1234,
|
|
||||||
},
|
|
||||||
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6",
|
|
||||||
key: ICMPConnKey{
|
|
||||||
SrcIP: netip.MustParseAddr("2001:db8::1"),
|
|
||||||
DstIP: netip.MustParseAddr("2001:db8::2"),
|
|
||||||
ID: 5678,
|
|
||||||
},
|
|
||||||
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got := tc.key.String()
|
|
||||||
if got != tc.expect {
|
|
||||||
t.Errorf("got %q, want %q", got, tc.expect)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ import (
|
|||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
@@ -36,10 +35,8 @@ import (
|
|||||||
const (
|
const (
|
||||||
layerTypeAll = 255
|
layerTypeAll = 255
|
||||||
|
|
||||||
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
|
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||||
ipv4TCPHeaderMinSize = 40
|
ipTCPHeaderMinSize = 40
|
||||||
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
|
|
||||||
ipv6TCPHeaderMinSize = 60
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// serviceKey represents a protocol/port combination for netstack service registry
|
// serviceKey represents a protocol/port combination for netstack service registry
|
||||||
@@ -126,7 +123,7 @@ type Manager struct {
|
|||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
|
||||||
blockRules []firewall.Rule
|
blockRule firewall.Rule
|
||||||
|
|
||||||
// Internal 1:1 DNAT
|
// Internal 1:1 DNAT
|
||||||
dnatEnabled atomic.Bool
|
dnatEnabled atomic.Bool
|
||||||
@@ -141,10 +138,9 @@ type Manager struct {
|
|||||||
netstackServices map[serviceKey]struct{}
|
netstackServices map[serviceKey]struct{}
|
||||||
netstackServiceMutex sync.RWMutex
|
netstackServiceMutex sync.RWMutex
|
||||||
|
|
||||||
mtu uint16
|
mtu uint16
|
||||||
mssClampValueIPv4 uint16
|
mssClampValue uint16
|
||||||
mssClampValueIPv6 uint16
|
mssClampEnabled bool
|
||||||
mssClampEnabled bool
|
|
||||||
|
|
||||||
// Only one hook per protocol is supported. Outbound direction only.
|
// Only one hook per protocol is supported. Outbound direction only.
|
||||||
udpHookOut atomic.Pointer[common.PacketHook]
|
udpHookOut atomic.Pointer[common.PacketHook]
|
||||||
@@ -161,28 +157,11 @@ type decoder struct {
|
|||||||
icmp4 layers.ICMPv4
|
icmp4 layers.ICMPv4
|
||||||
icmp6 layers.ICMPv6
|
icmp6 layers.ICMPv6
|
||||||
decoded []gopacket.LayerType
|
decoded []gopacket.LayerType
|
||||||
parser4 *gopacket.DecodingLayerParser
|
parser *gopacket.DecodingLayerParser
|
||||||
parser6 *gopacket.DecodingLayerParser
|
|
||||||
|
|
||||||
dnatOrigPort uint16
|
dnatOrigPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodePacket decodes packet data using the appropriate parser based on IP version.
|
|
||||||
func (d *decoder) decodePacket(data []byte) error {
|
|
||||||
if len(data) == 0 {
|
|
||||||
return errors.New("empty packet")
|
|
||||||
}
|
|
||||||
version := data[0] >> 4
|
|
||||||
switch version {
|
|
||||||
case 4:
|
|
||||||
return d.parser4.DecodeLayers(data, &d.decoded)
|
|
||||||
case 6:
|
|
||||||
return d.parser6.DecodeLayers(data, &d.decoded)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown IP version %d", version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
@@ -240,17 +219,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser4.IgnoreUnsupported = true
|
d.parser.IgnoreUnsupported = true
|
||||||
|
|
||||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv6,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser6.IgnoreUnsupported = true
|
|
||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -276,12 +249,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
|
|
||||||
if !disableMSSClamping {
|
if !disableMSSClamping {
|
||||||
m.mssClampEnabled = true
|
m.mssClampEnabled = true
|
||||||
if mtu > ipv4TCPHeaderMinSize {
|
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||||
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
|
||||||
}
|
|
||||||
if mtu > ipv6TCPHeaderMinSize {
|
|
||||||
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
@@ -304,25 +272,13 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||||
// arrives via the routing path. v4 and v6 are independent: a v6 install
|
|
||||||
// failure leaves v4 protection in place (and vice versa) so the returned
|
|
||||||
// slice always contains whatever was successfully installed, even on error.
|
|
||||||
// Callers must persist the slice so DisableRouting can clean partial state.
|
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
|
|
||||||
wgPrefix := iface.Address().Network
|
wgPrefix := iface.Address().Network
|
||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
rule, err := m.addRouteFiltering(
|
||||||
v6Net := iface.Address().IPv6Net
|
|
||||||
if v6Net.IsValid() {
|
|
||||||
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []firewall.Rule
|
|
||||||
v4Rule, err := m.addRouteFiltering(
|
|
||||||
nil,
|
nil,
|
||||||
sources,
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
firewall.Network{Prefix: wgPrefix},
|
firewall.Network{Prefix: wgPrefix},
|
||||||
firewall.ProtocolALL,
|
firewall.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
@@ -330,30 +286,12 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
|||||||
firewall.ActionDrop,
|
firewall.ActionDrop,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return rules, fmt.Errorf("block wg v4 net: %w", err)
|
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||||
}
|
|
||||||
rules = append(rules, v4Rule)
|
|
||||||
|
|
||||||
if v6Net.IsValid() {
|
|
||||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
|
||||||
v6Rule, err := m.addRouteFiltering(
|
|
||||||
nil,
|
|
||||||
sources,
|
|
||||||
firewall.Network{Prefix: v6Net},
|
|
||||||
firewall.ProtocolALL,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
firewall.ActionDrop,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return rules, fmt.Errorf("block wg v6 net: %w", err)
|
|
||||||
}
|
|
||||||
rules = append(rules, v6Rule)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Block networks that we're a client of
|
// TODO: Block networks that we're a client of
|
||||||
|
|
||||||
return rules, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) determineRouting() error {
|
func (m *Manager) determineRouting() error {
|
||||||
@@ -583,7 +521,7 @@ func (m *Manager) addRouteFiltering(
|
|||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
|
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||||
srcPort: sPort,
|
srcPort: sPort,
|
||||||
dstPort: dPort,
|
dstPort: dPort,
|
||||||
action: action,
|
action: action,
|
||||||
@@ -674,10 +612,10 @@ func (m *Manager) Flush() error { return nil }
|
|||||||
// resetState clears all firewall rules and closes connection trackers.
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
// Must be called with m.mutex held.
|
// Must be called with m.mutex held.
|
||||||
func (m *Manager) resetState() {
|
func (m *Manager) resetState() {
|
||||||
clear(m.outgoingRules)
|
maps.Clear(m.outgoingRules)
|
||||||
clear(m.incomingDenyRules)
|
maps.Clear(m.incomingDenyRules)
|
||||||
clear(m.incomingRules)
|
maps.Clear(m.incomingRules)
|
||||||
clear(m.routeRulesMap)
|
maps.Clear(m.routeRulesMap)
|
||||||
m.routeRules = m.routeRules[:0]
|
m.routeRules = m.routeRules[:0]
|
||||||
m.udpHookOut.Store(nil)
|
m.udpHookOut.Store(nil)
|
||||||
m.tcpHookOut.Store(nil)
|
m.tcpHookOut.Store(nil)
|
||||||
@@ -738,7 +676,11 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
destinations := matches[0].destinations
|
destinations := matches[0].destinations
|
||||||
destinations = append(destinations, prefixes...)
|
for _, prefix := range prefixes {
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
destinations = append(destinations, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||||
cmp := a.Addr().Compare(b.Addr())
|
cmp := a.Addr().Compare(b.Addr())
|
||||||
@@ -777,7 +719,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
|||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -861,32 +803,12 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
var mssClampValue uint16
|
|
||||||
var ipHeaderSize int
|
|
||||||
switch d.decoded[0] {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
mssClampValue = m.mssClampValueIPv4
|
|
||||||
ipHeaderSize = int(d.ip4.IHL) * 4
|
|
||||||
if ipHeaderSize < 20 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
mssClampValue = m.mssClampValueIPv6
|
|
||||||
ipHeaderSize = 40
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if mssClampValue == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
mssOptionIndex := -1
|
mssOptionIndex := -1
|
||||||
var currentMSS uint16
|
var currentMSS uint16
|
||||||
for i, opt := range d.tcp.Options {
|
for i, opt := range d.tcp.Options {
|
||||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||||
if currentMSS > mssClampValue {
|
if currentMSS > m.mssClampValue {
|
||||||
mssOptionIndex = i
|
mssOptionIndex = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -897,15 +819,20 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
|
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderSize < 20 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
|
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
|
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||||
tcpHeaderStart := ipHeaderSize
|
tcpHeaderStart := ipHeaderSize
|
||||||
tcpOptionsStart := tcpHeaderStart + 20
|
tcpOptionsStart := tcpHeaderStart + 20
|
||||||
|
|
||||||
@@ -920,7 +847,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex
|
|||||||
}
|
}
|
||||||
|
|
||||||
mssValueOffset := optOffset + 2
|
mssValueOffset := optOffset + 2
|
||||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
|
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||||
|
|
||||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||||
return true
|
return true
|
||||||
@@ -930,32 +857,18 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
|||||||
tcpLayer := packetData[tcpHeaderStart:]
|
tcpLayer := packetData[tcpHeaderStart:]
|
||||||
tcpLength := len(packetData) - tcpHeaderStart
|
tcpLength := len(packetData) - tcpHeaderStart
|
||||||
|
|
||||||
// Zero out existing checksum
|
|
||||||
tcpLayer[16] = 0
|
tcpLayer[16] = 0
|
||||||
tcpLayer[17] = 0
|
tcpLayer[17] = 0
|
||||||
|
|
||||||
// Build pseudo-header checksum based on IP version
|
|
||||||
var pseudoSum uint32
|
var pseudoSum uint32
|
||||||
switch d.decoded[0] {
|
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||||
case layers.LayerTypeIPv4:
|
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
pseudoSum += uint32(d.ip4.Protocol)
|
||||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
pseudoSum += uint32(tcpLength)
|
||||||
pseudoSum += uint32(d.ip4.Protocol)
|
|
||||||
pseudoSum += uint32(tcpLength)
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
for i := 0; i < 16; i += 2 {
|
|
||||||
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
|
|
||||||
}
|
|
||||||
for i := 0; i < 16; i += 2 {
|
|
||||||
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
|
|
||||||
}
|
|
||||||
pseudoSum += uint32(tcpLength)
|
|
||||||
pseudoSum += uint32(layers.IPProtocolTCP)
|
|
||||||
}
|
|
||||||
|
|
||||||
sum := pseudoSum
|
var sum = pseudoSum
|
||||||
for i := 0; i < tcpLength-1; i += 2 {
|
for i := 0; i < tcpLength-1; i += 2 {
|
||||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||||
}
|
}
|
||||||
@@ -993,9 +906,6 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
|
|||||||
}
|
}
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||||
case layers.LayerTypeICMPv6:
|
|
||||||
id, tc := icmpv6EchoFields(d)
|
|
||||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1009,9 +919,6 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||||
case layers.LayerTypeICMPv6:
|
|
||||||
id, tc := icmpv6EchoFields(d)
|
|
||||||
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
d.dnatOrigPort = 0
|
d.dnatOrigPort = 0
|
||||||
@@ -1044,19 +951,15 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
|||||||
|
|
||||||
// TODO: pass fragments of routed packets to forwarder
|
// TODO: pass fragments of routed packets to forwarder
|
||||||
if fragment {
|
if fragment {
|
||||||
if d.decoded[0] == layers.LayerTypeIPv4 {
|
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
|
||||||
} else {
|
|
||||||
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
|
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||||
// Re-decode after port DNAT translation to update port information
|
// Re-decode after port DNAT translation to update port information
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -1065,7 +968,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
|||||||
|
|
||||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||||
// Re-decode after translation to get original addresses
|
// Re-decode after translation to get original addresses
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -1197,48 +1100,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
|
|
||||||
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
|
|
||||||
// both families uniformly. The echo ID is in the first two payload bytes.
|
|
||||||
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
|
|
||||||
if len(d.icmp6.Payload) >= 2 {
|
|
||||||
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
|
||||||
}
|
|
||||||
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
|
|
||||||
switch d.icmp6.TypeCode.Type() {
|
|
||||||
case layers.ICMPv6TypeEchoRequest:
|
|
||||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
|
|
||||||
case layers.ICMPv6TypeEchoReply:
|
|
||||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
|
|
||||||
default:
|
|
||||||
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
|
|
||||||
}
|
|
||||||
return id, tc
|
|
||||||
}
|
|
||||||
|
|
||||||
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
|
|
||||||
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
|
|
||||||
// ICMP rules since management sends a single ICMP rule for both families.
|
|
||||||
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
|
|
||||||
if ruleLayer == packetLayer {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
|
|
||||||
if p.Addr().Is6() {
|
|
||||||
return layers.LayerTypeIPv6
|
|
||||||
}
|
|
||||||
return layers.LayerTypeIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||||
switch proto {
|
switch proto {
|
||||||
case firewall.ProtocolTCP:
|
case firewall.ProtocolTCP:
|
||||||
@@ -1262,10 +1123,8 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
|||||||
return nftypes.TCP
|
return nftypes.TCP
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
return nftypes.UDP
|
return nftypes.UDP
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return nftypes.ICMP
|
return nftypes.ICMP
|
||||||
case layers.LayerTypeICMPv6:
|
|
||||||
return nftypes.ICMPv6
|
|
||||||
default:
|
default:
|
||||||
return nftypes.ProtocolUnknown
|
return nftypes.ProtocolUnknown
|
||||||
}
|
}
|
||||||
@@ -1286,7 +1145,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
|||||||
// It returns true, false if the packet is valid and not a fragment.
|
// It returns true, false if the packet is valid and not a fragment.
|
||||||
// It returns true, true if the packet is a fragment and valid.
|
// It returns true, true if the packet is a fragment and valid.
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||||
return false, false
|
return false, false
|
||||||
}
|
}
|
||||||
@@ -1299,21 +1158,10 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fragments are also valid
|
// Fragments are also valid
|
||||||
if l == 1 {
|
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||||
switch d.decoded[0] {
|
ip4 := d.ip4
|
||||||
case layers.LayerTypeIPv4:
|
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||||
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
|
return true, true
|
||||||
return true, true
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
|
|
||||||
// only decoded the IPv6 layer, the transport is in a fragment.
|
|
||||||
// TODO: handle non-Fragment extension headers (HopByHop, Routing,
|
|
||||||
// DestOpts) by walking the chain. gopacket's parser does not
|
|
||||||
// support them as DecodingLayers; today we drop such packets.
|
|
||||||
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
|
|
||||||
return true, true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1351,35 +1199,21 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
|||||||
size,
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
case layers.LayerTypeICMPv6:
|
// TODO: ICMPv6
|
||||||
id, _ := icmpv6EchoFields(d)
|
|
||||||
return m.icmpTracker.IsValidInbound(
|
|
||||||
srcIP,
|
|
||||||
dstIP,
|
|
||||||
id,
|
|
||||||
d.icmp6.TypeCode.Type(),
|
|
||||||
size,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
|
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||||
switch d.decoded[1] {
|
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||||
case layers.LayerTypeICMPv4:
|
return false
|
||||||
icmpType := d.icmp4.TypeCode.Type()
|
|
||||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
|
||||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
|
||||||
case layers.LayerTypeICMPv6:
|
|
||||||
icmpType := d.icmp6.TypeCode.Type()
|
|
||||||
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
|
|
||||||
icmpType == layers.ICMPv6TypePacketTooBig ||
|
|
||||||
icmpType == layers.ICMPv6TypeTimeExceeded ||
|
|
||||||
icmpType == layers.ICMPv6TypeParameterProblem
|
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
|
icmpType := d.icmp4.TypeCode.Type()
|
||||||
|
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||||
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
|
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
|
||||||
@@ -1436,7 +1270,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
|||||||
return rule.mgmtId, rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
if payloadLayer != rule.protoLayer {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1471,7 +1305,8 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||||
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
|
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||||
|
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1532,14 +1367,13 @@ func (m *Manager) EnableRouting() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := m.blockInvalidRouted(m.wgIface)
|
rule, err := m.blockInvalidRouted(m.wgIface)
|
||||||
// Persist whatever was installed even on partial failure, so DisableRouting
|
|
||||||
// can clean it up later.
|
|
||||||
m.blockRules = rules
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("block invalid routed: %w", err)
|
return fmt.Errorf("block invalid routed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.blockRule = rule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1555,16 +1389,9 @@ func (m *Manager) DisableRouting() error {
|
|||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter.Store(false)
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
var merr *multierror.Error
|
// don't stop forwarder if in use by netstack
|
||||||
for _, rule := range m.blockRules {
|
|
||||||
if err := m.deleteRouteRule(rule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete block rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.blockRules = nil
|
|
||||||
|
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fwder.Stop()
|
fwder.Stop()
|
||||||
@@ -1572,7 +1399,14 @@ func (m *Manager) DisableRouting() error {
|
|||||||
|
|
||||||
log.Debug("forwarder stopped")
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
if m.blockRule != nil {
|
||||||
|
if err := m.deleteRouteRule(m.blockRule); err != nil {
|
||||||
|
return fmt.Errorf("delete block rule: %w", err)
|
||||||
|
}
|
||||||
|
m.blockRule = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||||
@@ -1626,8 +1460,7 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||||
addr := m.wgIface.Address()
|
if dstIP != m.wgIface.Address().IP {
|
||||||
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1023,8 +1023,7 @@ func BenchmarkMSSClamping(b *testing.B) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
manager.mssClampEnabled = true
|
manager.mssClampEnabled = true
|
||||||
manager.mssClampValueIPv4 = 1240
|
manager.mssClampValue = 1240
|
||||||
manager.mssClampValueIPv6 = 1220
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.2")
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
dstIP := net.ParseIP("8.8.8.8")
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
@@ -1089,8 +1088,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
|||||||
|
|
||||||
manager.mssClampEnabled = sc.enabled
|
manager.mssClampEnabled = sc.enabled
|
||||||
if sc.enabled {
|
if sc.enabled {
|
||||||
manager.mssClampValueIPv4 = 1240
|
manager.mssClampValue = 1240
|
||||||
manager.mssClampValueIPv6 = 1220
|
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.2")
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
@@ -1143,8 +1141,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
manager.mssClampEnabled = true
|
manager.mssClampEnabled = true
|
||||||
manager.mssClampValueIPv4 = 1240
|
manager.mssClampValue = 1240
|
||||||
manager.mssClampValueIPv6 = 1220
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.2")
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
dstIP := net.ParseIP("8.8.8.8")
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
|||||||
@@ -539,236 +539,53 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerACLFilteringIPv6(t *testing.T) {
|
|
||||||
localIP := netip.MustParseAddr("100.10.0.100")
|
|
||||||
localIPv6 := netip.MustParseAddr("fd00::100")
|
|
||||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
|
||||||
wgNetV6 := netip.MustParsePrefix("fd00::/64")
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: localIP,
|
|
||||||
Network: wgNet,
|
|
||||||
IPv6: localIPv6,
|
|
||||||
IPv6Net: wgNetV6,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
srcIP string
|
|
||||||
dstIP string
|
|
||||||
proto fw.Protocol
|
|
||||||
srcPort uint16
|
|
||||||
dstPort uint16
|
|
||||||
ruleIP string
|
|
||||||
ruleProto fw.Protocol
|
|
||||||
ruleDstPort *fw.Port
|
|
||||||
ruleAction fw.Action
|
|
||||||
shouldBeBlocked bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "IPv6: allow TCP from peer",
|
|
||||||
srcIP: "fd00::1",
|
|
||||||
dstIP: "fd00::100",
|
|
||||||
proto: fw.ProtocolTCP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 443,
|
|
||||||
ruleIP: "fd00::1",
|
|
||||||
ruleProto: fw.ProtocolTCP,
|
|
||||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
|
||||||
ruleAction: fw.ActionAccept,
|
|
||||||
shouldBeBlocked: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6: allow UDP from peer",
|
|
||||||
srcIP: "fd00::1",
|
|
||||||
dstIP: "fd00::100",
|
|
||||||
proto: fw.ProtocolUDP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 53,
|
|
||||||
ruleIP: "fd00::1",
|
|
||||||
ruleProto: fw.ProtocolUDP,
|
|
||||||
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
|
||||||
ruleAction: fw.ActionAccept,
|
|
||||||
shouldBeBlocked: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6: allow ICMPv6 from peer",
|
|
||||||
srcIP: "fd00::1",
|
|
||||||
dstIP: "fd00::100",
|
|
||||||
proto: fw.ProtocolICMP,
|
|
||||||
ruleIP: "fd00::1",
|
|
||||||
ruleProto: fw.ProtocolICMP,
|
|
||||||
ruleAction: fw.ActionAccept,
|
|
||||||
shouldBeBlocked: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6: block TCP without rule",
|
|
||||||
srcIP: "fd00::2",
|
|
||||||
dstIP: "fd00::100",
|
|
||||||
proto: fw.ProtocolTCP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 443,
|
|
||||||
ruleIP: "fd00::1",
|
|
||||||
ruleProto: fw.ProtocolTCP,
|
|
||||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
|
||||||
ruleAction: fw.ActionAccept,
|
|
||||||
shouldBeBlocked: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6: drop rule",
|
|
||||||
srcIP: "fd00::1",
|
|
||||||
dstIP: "fd00::100",
|
|
||||||
proto: fw.ProtocolTCP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 22,
|
|
||||||
ruleIP: "fd00::1",
|
|
||||||
ruleProto: fw.ProtocolTCP,
|
|
||||||
ruleDstPort: &fw.Port{Values: []uint16{22}},
|
|
||||||
ruleAction: fw.ActionDrop,
|
|
||||||
shouldBeBlocked: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6: allow all protocols",
|
|
||||||
srcIP: "fd00::1",
|
|
||||||
dstIP: "fd00::100",
|
|
||||||
proto: fw.ProtocolUDP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 9999,
|
|
||||||
ruleIP: "fd00::1",
|
|
||||||
ruleProto: fw.ProtocolALL,
|
|
||||||
ruleAction: fw.ActionAccept,
|
|
||||||
shouldBeBlocked: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
|
|
||||||
srcIP: "fd00::1",
|
|
||||||
dstIP: "fd00::100",
|
|
||||||
proto: fw.ProtocolICMP,
|
|
||||||
ruleIP: "0.0.0.0",
|
|
||||||
ruleProto: fw.ProtocolICMP,
|
|
||||||
ruleAction: fw.ActionAccept,
|
|
||||||
shouldBeBlocked: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
|
|
||||||
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
|
|
||||||
isDropped := manager.FilterInbound(packet, 0)
|
|
||||||
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
if tc.ruleAction == fw.ActionDrop {
|
|
||||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
for _, rule := range rules {
|
|
||||||
require.NoError(t, manager.DeletePeerRule(rule))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, rules)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
for _, rule := range rules {
|
|
||||||
require.NoError(t, manager.DeletePeerRule(rule))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
|
||||||
isDropped := manager.FilterInbound(packet, 0)
|
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
|
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
src := net.ParseIP(srcIP)
|
|
||||||
dst := net.ParseIP(dstIP)
|
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
buf := gopacket.NewSerializeBuffer()
|
||||||
opts := gopacket.SerializeOptions{
|
opts := gopacket.SerializeOptions{
|
||||||
ComputeChecksums: true,
|
ComputeChecksums: true,
|
||||||
FixLengths: true,
|
FixLengths: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect address family
|
ipLayer := &layers.IPv4{
|
||||||
isV6 := src.To4() == nil
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
SrcIP: net.ParseIP(srcIP),
|
||||||
|
DstIP: net.ParseIP(dstIP),
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
switch proto {
|
||||||
|
case fw.ProtocolTCP:
|
||||||
|
ipLayer.Protocol = layers.IPProtocolTCP
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
}
|
||||||
|
err = tcp.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
|
||||||
|
|
||||||
if isV6 {
|
case fw.ProtocolUDP:
|
||||||
ip6 := &layers.IPv6{
|
ipLayer.Protocol = layers.IPProtocolUDP
|
||||||
Version: 6,
|
udp := &layers.UDP{
|
||||||
HopLimit: 64,
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
SrcIP: src,
|
DstPort: layers.UDPPort(dstPort),
|
||||||
DstIP: dst,
|
|
||||||
}
|
}
|
||||||
|
err = udp.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
|
||||||
|
|
||||||
switch proto {
|
case fw.ProtocolICMP:
|
||||||
case fw.ProtocolTCP:
|
ipLayer.Protocol = layers.IPProtocolICMPv4
|
||||||
ip6.NextHeader = layers.IPProtocolTCP
|
icmp := &layers.ICMPv4{
|
||||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||||
_ = tcp.SetNetworkLayerForChecksum(ip6)
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
|
|
||||||
case fw.ProtocolUDP:
|
|
||||||
ip6.NextHeader = layers.IPProtocolUDP
|
|
||||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
|
||||||
_ = udp.SetNetworkLayerForChecksum(ip6)
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
|
|
||||||
case fw.ProtocolICMP:
|
|
||||||
ip6.NextHeader = layers.IPProtocolICMPv6
|
|
||||||
icmp := &layers.ICMPv6{
|
|
||||||
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
|
|
||||||
}
|
|
||||||
_ = icmp.SetNetworkLayerForChecksum(ip6)
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
|
|
||||||
default:
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip6)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ip4 := &layers.IPv4{
|
|
||||||
Version: 4,
|
|
||||||
TTL: 64,
|
|
||||||
SrcIP: src,
|
|
||||||
DstIP: dst,
|
|
||||||
}
|
}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
|
||||||
|
|
||||||
switch proto {
|
default:
|
||||||
case fw.ProtocolTCP:
|
err = gopacket.SerializeLayers(buf, opts, ipLayer)
|
||||||
ip4.Protocol = layers.IPProtocolTCP
|
|
||||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
|
||||||
_ = tcp.SetNetworkLayerForChecksum(ip4)
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
|
|
||||||
case fw.ProtocolUDP:
|
|
||||||
ip4.Protocol = layers.IPProtocolUDP
|
|
||||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
|
||||||
_ = udp.SetNetworkLayerForChecksum(ip4)
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
|
|
||||||
case fw.ProtocolICMP:
|
|
||||||
ip4.Protocol = layers.IPProtocolICMPv4
|
|
||||||
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
|
|
||||||
default:
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ip4)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1681,103 +1498,3 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
|
|
||||||
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
|
|
||||||
// but the ACL decision itself is correct.
|
|
||||||
func TestRouteACLFilteringIPv6(t *testing.T) {
|
|
||||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
|
||||||
|
|
||||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
|
||||||
_, err := manager.AddRouteFiltering(
|
|
||||||
nil,
|
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
|
||||||
fw.Network{Prefix: v6Dst},
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
|
||||||
nil,
|
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
|
||||||
fw.ProtocolALL,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
fw.ActionDrop,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
srcIP netip.Addr
|
|
||||||
dstIP netip.Addr
|
|
||||||
proto gopacket.LayerType
|
|
||||||
srcPort uint16
|
|
||||||
dstPort uint16
|
|
||||||
allowed bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "IPv6 TCP to allowed dest",
|
|
||||||
srcIP: netip.MustParseAddr("fd00::1"),
|
|
||||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
|
||||||
proto: layers.LayerTypeTCP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 80,
|
|
||||||
allowed: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 TCP wrong port",
|
|
||||||
srcIP: netip.MustParseAddr("fd00::1"),
|
|
||||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
|
||||||
proto: layers.LayerTypeTCP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 443,
|
|
||||||
allowed: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 UDP not matched by TCP rule",
|
|
||||||
srcIP: netip.MustParseAddr("fd00::1"),
|
|
||||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
|
||||||
proto: layers.LayerTypeUDP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 80,
|
|
||||||
allowed: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
|
|
||||||
srcIP: netip.MustParseAddr("fd00::1"),
|
|
||||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
|
||||||
proto: layers.LayerTypeICMPv6,
|
|
||||||
allowed: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 to denied subnet",
|
|
||||||
srcIP: netip.MustParseAddr("fd00::1"),
|
|
||||||
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
|
|
||||||
proto: layers.LayerTypeTCP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 80,
|
|
||||||
allowed: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 source outside allowed range",
|
|
||||||
srcIP: netip.MustParseAddr("fe80::1"),
|
|
||||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
|
||||||
proto: layers.LayerTypeTCP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 80,
|
|
||||||
allowed: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
|
||||||
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -189,21 +189,21 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Call blockInvalidRouted directly multiple times
|
// Call blockInvalidRouted directly multiple times
|
||||||
rules1, err := manager.blockInvalidRouted(ifaceMock)
|
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules1)
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
rules2, err := manager.blockInvalidRouted(ifaceMock)
|
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules2)
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
rules3, err := manager.blockInvalidRouted(ifaceMock)
|
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules3)
|
require.NotNil(t, rule3)
|
||||||
|
|
||||||
// All calls should return the same v4 block rule (idempotent install).
|
// All should return the same rule
|
||||||
assert.Equal(t, rules1[0].ID(), rules2[0].ID(), "Second call should return same v4 rule")
|
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||||
assert.Equal(t, rules2[0].ID(), rules3[0].ID(), "Third call should return same v4 rule")
|
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||||
|
|
||||||
// Should have exactly 1 route rule
|
// Should have exactly 1 route rule
|
||||||
manager.mutex.RLock()
|
manager.mutex.RLock()
|
||||||
|
|||||||
@@ -535,16 +535,11 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser4.IgnoreUnsupported = true
|
d.parser.IgnoreUnsupported = true
|
||||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv6,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser6.IgnoreUnsupported = true
|
|
||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -643,16 +638,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser4.IgnoreUnsupported = true
|
d.parser.IgnoreUnsupported = true
|
||||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv6,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser6.IgnoreUnsupported = true
|
|
||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1058,8 +1048,8 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||||
require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
|
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||||
require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
|
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
err = manager.UpdateLocalIPs()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1077,7 +1067,7 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
|
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||||
@@ -1101,7 +1091,7 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
d := parsePacket(t, packet)
|
d := parsePacket(t, packet)
|
||||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
|
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||||
@@ -1273,18 +1263,13 @@ func TestShouldForward(t *testing.T) {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser4.IgnoreUnsupported = true
|
d.parser.IgnoreUnsupported = true
|
||||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv6,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser6.IgnoreUnsupported = true
|
|
||||||
|
|
||||||
err = d.decodePacket(buf.Bytes())
|
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return d
|
return d
|
||||||
@@ -1344,44 +1329,6 @@ func TestShouldForward(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add IPv6 to the interface and test dual-stack cases
|
|
||||||
wgIPv6 := netip.MustParseAddr("fd00::1")
|
|
||||||
otherIPv6 := netip.MustParseAddr("fd00::2")
|
|
||||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: wgIP,
|
|
||||||
Network: netip.PrefixFrom(wgIP, 24),
|
|
||||||
IPv6: wgIPv6,
|
|
||||||
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-create manager to pick up the new address with IPv6
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
v6Cases := []struct {
|
|
||||||
name string
|
|
||||||
dstIP netip.Addr
|
|
||||||
expected bool
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
|
|
||||||
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
|
|
||||||
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
|
|
||||||
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
|
|
||||||
}
|
|
||||||
for _, tt := range v6Cases {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
manager.localForwarding = true
|
|
||||||
manager.netstack = false
|
|
||||||
decoder := createTCPDecoder(8080)
|
|
||||||
result := manager.shouldForward(decoder, tt.dstIP)
|
|
||||||
require.Equal(t, tt.expected, result, tt.description)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Configure manager
|
// Configure manager
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
@@ -55,23 +54,16 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
|||||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
var written int
|
var written int
|
||||||
for _, pkt := range pkts.AsSlice() {
|
for _, pkt := range pkts.AsSlice() {
|
||||||
|
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||||
|
|
||||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||||
if data == nil {
|
if data == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
raw := pkt.NetworkHeader().View().AsSlice()
|
|
||||||
if len(raw) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var address tcpip.Address
|
|
||||||
if raw[0]>>4 == 6 {
|
|
||||||
address = header.IPv6(raw).DestinationAddress()
|
|
||||||
} else {
|
|
||||||
address = header.IPv4(raw).DestinationAddress()
|
|
||||||
}
|
|
||||||
|
|
||||||
pktBytes := data.AsSlice()
|
pktBytes := data.AsSlice()
|
||||||
|
|
||||||
|
address := netHeader.DestinationAddress()
|
||||||
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
|
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
|
||||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -122,7 +114,5 @@ type epID stack.TransportEndpointID
|
|||||||
|
|
||||||
func (i epID) String() string {
|
func (i epID) String() string {
|
||||||
// src and remote is swapped
|
// src and remote is swapped
|
||||||
return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
|
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
" → " +
|
|
||||||
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
@@ -37,31 +36,25 @@ type Forwarder struct {
|
|||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
// ruleIdMap is used to store the rule ID for a given connection
|
// ruleIdMap is used to store the rule ID for a given connection
|
||||||
ruleIdMap sync.Map
|
ruleIdMap sync.Map
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ip tcpip.Address
|
ip tcpip.Address
|
||||||
ipv6 tcpip.Address
|
netstack bool
|
||||||
netstack bool
|
hasRawICMPAccess bool
|
||||||
hasRawICMPAccess bool
|
pingSemaphore chan struct{}
|
||||||
hasRawICMPv6Access bool
|
|
||||||
pingSemaphore chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
ipv4.NewProtocol,
|
|
||||||
ipv6.NewProtocol,
|
|
||||||
},
|
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
tcp.NewProtocol,
|
tcp.NewProtocol,
|
||||||
udp.NewProtocol,
|
udp.NewProtocol,
|
||||||
icmp.NewProtocol4,
|
icmp.NewProtocol4,
|
||||||
icmp.NewProtocol6,
|
|
||||||
},
|
},
|
||||||
HandleLocal: false,
|
HandleLocal: false,
|
||||||
})
|
})
|
||||||
@@ -80,7 +73,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
PrefixLen: iface.Address().Network.Bits(),
|
PrefixLen: iface.Address().Network.Bits(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -89,19 +82,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
|
||||||
v6Addr := tcpip.ProtocolAddress{
|
|
||||||
Protocol: ipv6.ProtocolNumber,
|
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
|
||||||
Address: tcpip.AddrFrom16(v6.As16()),
|
|
||||||
PrefixLen: iface.Address().IPv6Net.Bits(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
|
|
||||||
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultSubnet, err := tcpip.NewSubnet(
|
defaultSubnet, err := tcpip.NewSubnet(
|
||||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||||
@@ -110,14 +90,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultSubnetV6, err := tcpip.NewSubnet(
|
|
||||||
tcpip.AddrFrom16([16]byte{}),
|
|
||||||
tcpip.MaskFromBytes(make([]byte, 16)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||||
}
|
}
|
||||||
@@ -126,8 +98,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.SetRouteTable([]tcpip.Route{
|
s.SetRouteTable([]tcpip.Route{
|
||||||
{Destination: defaultSubnet, NIC: nicID},
|
{
|
||||||
{Destination: defaultSubnetV6, NIC: nicID},
|
Destination: defaultSubnet,
|
||||||
|
NIC: nicID,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -140,8 +114,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
ipv6: addrFromNetipAddr(iface.Address().IPv6),
|
|
||||||
pingSemaphore: make(chan struct{}, 3),
|
pingSemaphore: make(chan struct{}, 3),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,10 +131,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
// ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
|
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||||
// network layer. This avoids duplicate echo replies (v4) and the v6
|
|
||||||
// auto-reply bug where gVisor responds at the network layer before
|
|
||||||
// our transport handler fires.
|
|
||||||
|
|
||||||
f.checkICMPCapability()
|
f.checkICMPCapability()
|
||||||
|
|
||||||
@@ -180,30 +150,8 @@ func (f *Forwarder) SetCapture(pc PacketCapture) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
if len(payload) == 0 {
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
return fmt.Errorf("empty packet")
|
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||||
}
|
|
||||||
|
|
||||||
var protoNum tcpip.NetworkProtocolNumber
|
|
||||||
switch payload[0] >> 4 {
|
|
||||||
case 4:
|
|
||||||
if len(payload) < header.IPv4MinimumSize {
|
|
||||||
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
|
|
||||||
}
|
|
||||||
if f.handleICMPDirect(payload) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
protoNum = ipv4.ProtocolNumber
|
|
||||||
case 6:
|
|
||||||
if len(payload) < header.IPv6MinimumSize {
|
|
||||||
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
|
|
||||||
}
|
|
||||||
if f.handleICMPDirect(payload) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
protoNum = ipv6.ProtocolNumber
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
@@ -212,160 +160,11 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
|||||||
defer pkt.DecRef()
|
defer pkt.DecRef()
|
||||||
|
|
||||||
if f.endpoint.dispatcher != nil {
|
if f.endpoint.dispatcher != nil {
|
||||||
f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
|
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
|
|
||||||
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
|
|
||||||
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
|
|
||||||
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
|
|
||||||
// and auto-replies to all v6 echo requests in promiscuous mode.
|
|
||||||
//
|
|
||||||
// Unlike gVisor's network layer, this does not validate ICMP checksums or
|
|
||||||
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
|
|
||||||
func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
|
|
||||||
if len(payload) < header.IPv4MinimumSize {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
ip := header.IPv4(payload)
|
|
||||||
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
ipHdrLen = int(ip.HeaderLength())
|
|
||||||
totalLen := int(ip.TotalLength())
|
|
||||||
if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
icmpLen = totalLen - ipHdrLen
|
|
||||||
if icmpLen < header.ICMPv4MinimumSize {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
|
|
||||||
if len(payload) < header.IPv6MinimumSize {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
ip := header.IPv6(payload)
|
|
||||||
declaredLen := int(ip.PayloadLength())
|
|
||||||
hdrEnd := header.IPv6MinimumSize + declaredLen
|
|
||||||
if hdrEnd > len(payload) {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
icmpStart, ok := skipIPv6ExtensionsToICMPv6(payload, ip.NextHeader(), hdrEnd)
|
|
||||||
if !ok {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
icmpLen = hdrEnd - icmpStart
|
|
||||||
if icmpLen < header.ICMPv6MinimumSize {
|
|
||||||
return 0, 0, src, dst, false
|
|
||||||
}
|
|
||||||
return icmpStart, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
|
||||||
}
|
|
||||||
|
|
||||||
// skipIPv6ExtensionsToICMPv6 walks the IPv6 extension-header chain starting
|
|
||||||
// after the fixed header. It advances past Hop-by-Hop, Routing, and
|
|
||||||
// Destination Options headers (which share the NextHeader+ExtLen+6+ExtLen*8
|
|
||||||
// layout) and returns the offset of the ICMPv6 payload. Fragment, ESP, AH,
|
|
||||||
// and unknown identifiers are reported as not handleable so the caller can
|
|
||||||
// defer to gVisor.
|
|
||||||
func skipIPv6ExtensionsToICMPv6(payload []byte, next uint8, hdrEnd int) (int, bool) {
|
|
||||||
off := header.IPv6MinimumSize
|
|
||||||
for {
|
|
||||||
if next == uint8(header.ICMPv6ProtocolNumber) {
|
|
||||||
return off, true
|
|
||||||
}
|
|
||||||
if !isWalkableIPv6ExtHdr(next) {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
newOff, newNext, ok := advanceIPv6ExtHdr(payload, off, hdrEnd)
|
|
||||||
if !ok {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
off = newOff
|
|
||||||
next = newNext
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isWalkableIPv6ExtHdr(id uint8) bool {
|
|
||||||
switch id {
|
|
||||||
case uint8(header.IPv6HopByHopOptionsExtHdrIdentifier),
|
|
||||||
uint8(header.IPv6RoutingExtHdrIdentifier),
|
|
||||||
uint8(header.IPv6DestinationOptionsExtHdrIdentifier):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func advanceIPv6ExtHdr(payload []byte, off, hdrEnd int) (int, uint8, bool) {
|
|
||||||
if off+8 > hdrEnd {
|
|
||||||
return 0, 0, false
|
|
||||||
}
|
|
||||||
extLen := (int(payload[off+1]) + 1) * 8
|
|
||||||
if off+extLen > hdrEnd {
|
|
||||||
return 0, 0, false
|
|
||||||
}
|
|
||||||
return off + extLen, payload[off], true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
|
|
||||||
if len(payload) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
ipHdrLen int
|
|
||||||
icmpLen int
|
|
||||||
srcAddr tcpip.Address
|
|
||||||
dstAddr tcpip.Address
|
|
||||||
ok bool
|
|
||||||
)
|
|
||||||
version := payload[0] >> 4
|
|
||||||
switch version {
|
|
||||||
case 4:
|
|
||||||
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
|
|
||||||
case 6:
|
|
||||||
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Let gVisor handle ICMP destined for our own addresses natively.
|
|
||||||
// Its network-layer auto-reply is correct and efficient for local traffic.
|
|
||||||
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
id := stack.TransportEndpointID{
|
|
||||||
LocalAddress: dstAddr,
|
|
||||||
RemoteAddress: srcAddr,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trim the buffer to the IP-declared length so gVisor doesn't see padding.
|
|
||||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
||||||
Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]),
|
|
||||||
})
|
|
||||||
defer pkt.DecRef()
|
|
||||||
|
|
||||||
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if version == 6 {
|
|
||||||
return f.handleICMPv6(id, pkt)
|
|
||||||
}
|
|
||||||
return f.handleICMP(id, pkt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop gracefully shuts down the forwarder
|
// Stop gracefully shuts down the forwarder
|
||||||
func (f *Forwarder) Stop() {
|
func (f *Forwarder) Stop() {
|
||||||
f.cancel()
|
f.cancel()
|
||||||
@@ -378,14 +177,11 @@ func (f *Forwarder) Stop() {
|
|||||||
f.stack.Wait()
|
f.stack.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
if f.netstack && f.ip.Equal(addr) {
|
if f.netstack && f.ip.Equal(addr) {
|
||||||
return netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
return net.IPv4(127, 0, 0, 1)
|
||||||
}
|
}
|
||||||
if f.netstack && f.ipv6.Equal(addr) {
|
return addr.AsSlice()
|
||||||
return netip.IPv6Loopback()
|
|
||||||
}
|
|
||||||
return addrToNetipAddr(addr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||||
@@ -419,50 +215,23 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
|
|
||||||
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
|
|
||||||
if !addr.IsValid() {
|
|
||||||
return tcpip.Address{}
|
|
||||||
}
|
|
||||||
if addr.Is4() {
|
|
||||||
return tcpip.AddrFrom4(addr.As4())
|
|
||||||
}
|
|
||||||
return tcpip.AddrFrom16(addr.As16())
|
|
||||||
}
|
|
||||||
|
|
||||||
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
|
|
||||||
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
|
|
||||||
switch addr.Len() {
|
|
||||||
case 4:
|
|
||||||
return netip.AddrFrom4(addr.As4())
|
|
||||||
case 16:
|
|
||||||
return netip.AddrFrom16(addr.As16())
|
|
||||||
default:
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||||
func (f *Forwarder) checkICMPCapability() {
|
func (f *Forwarder) checkICMPCapability() {
|
||||||
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
|
|
||||||
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
|
|
||||||
}
|
|
||||||
|
|
||||||
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
lc := net.ListenConfig{}
|
lc := net.ListenConfig{}
|
||||||
conn, err := lc.ListenPacket(ctx, network, addr)
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
|
f.hasRawICMPAccess = false
|
||||||
return false
|
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
|
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug1("forwarder: raw %s socket access available", network)
|
f.hasRawICMPAccess = true
|
||||||
return true
|
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,162 +0,0 @@
|
|||||||
package forwarder
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
||||||
)
|
|
||||||
|
|
||||||
const echoRequestSize = 8
|
|
||||||
|
|
||||||
func makeIPv6(t *testing.T, src, dst netip.Addr, nextHdr uint8, payload []byte) []byte {
|
|
||||||
t.Helper()
|
|
||||||
buf := make([]byte, header.IPv6MinimumSize+len(payload))
|
|
||||||
ip := header.IPv6(buf)
|
|
||||||
ip.Encode(&header.IPv6Fields{
|
|
||||||
PayloadLength: uint16(len(payload)),
|
|
||||||
TransportProtocol: 0, // overwritten below to allow any value
|
|
||||||
HopLimit: 64,
|
|
||||||
SrcAddr: tcpipAddrFromNetip(src),
|
|
||||||
DstAddr: tcpipAddrFromNetip(dst),
|
|
||||||
})
|
|
||||||
buf[6] = nextHdr
|
|
||||||
copy(buf[header.IPv6MinimumSize:], payload)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcpipAddrFromNetip(a netip.Addr) tcpip.Address {
|
|
||||||
b := a.As16()
|
|
||||||
return tcpip.AddrFrom16(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func echoRequest() []byte {
|
|
||||||
icmp := make([]byte, echoRequestSize)
|
|
||||||
icmp[0] = uint8(header.ICMPv6EchoRequest)
|
|
||||||
return icmp
|
|
||||||
}
|
|
||||||
|
|
||||||
// extHdr builds a generic IPv6 extension header (HBH/Routing/DestOpts) of the
|
|
||||||
// given total octet length (must be multiple of 8, >= 8) with the given next
|
|
||||||
// header.
|
|
||||||
func extHdr(t *testing.T, next uint8, totalLen int) []byte {
|
|
||||||
t.Helper()
|
|
||||||
require.GreaterOrEqual(t, totalLen, 8)
|
|
||||||
require.Equal(t, 0, totalLen%8)
|
|
||||||
buf := make([]byte, totalLen)
|
|
||||||
buf[0] = next
|
|
||||||
buf[1] = uint8(totalLen/8 - 1)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv6_NoExtensions(t *testing.T) {
|
|
||||||
src := netip.MustParseAddr("fd00::1")
|
|
||||||
dst := netip.MustParseAddr("fd00::2")
|
|
||||||
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), echoRequest())
|
|
||||||
|
|
||||||
off, icmpLen, _, _, ok := parseICMPv6(pkt)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, header.IPv6MinimumSize, off)
|
|
||||||
assert.Equal(t, echoRequestSize, icmpLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv6_SingleExtension(t *testing.T) {
|
|
||||||
src := netip.MustParseAddr("fd00::1")
|
|
||||||
dst := netip.MustParseAddr("fd00::2")
|
|
||||||
hbh := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 8)
|
|
||||||
payload := append([]byte{}, hbh...)
|
|
||||||
payload = append(payload, echoRequest()...)
|
|
||||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
|
|
||||||
|
|
||||||
off, icmpLen, _, _, ok := parseICMPv6(pkt)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, header.IPv6MinimumSize+8, off)
|
|
||||||
assert.Equal(t, echoRequestSize, icmpLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv6_ChainedExtensions(t *testing.T) {
|
|
||||||
src := netip.MustParseAddr("fd00::1")
|
|
||||||
dst := netip.MustParseAddr("fd00::2")
|
|
||||||
dest := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 16)
|
|
||||||
rt := extHdr(t, uint8(header.IPv6DestinationOptionsExtHdrIdentifier), 8)
|
|
||||||
hbh := extHdr(t, uint8(header.IPv6RoutingExtHdrIdentifier), 8)
|
|
||||||
payload := append(append(append([]byte{}, hbh...), rt...), dest...)
|
|
||||||
payload = append(payload, echoRequest()...)
|
|
||||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
|
|
||||||
|
|
||||||
off, icmpLen, _, _, ok := parseICMPv6(pkt)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, header.IPv6MinimumSize+8+8+16, off)
|
|
||||||
assert.Equal(t, echoRequestSize, icmpLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv6_FragmentDefersToGVisor(t *testing.T) {
|
|
||||||
src := netip.MustParseAddr("fd00::1")
|
|
||||||
dst := netip.MustParseAddr("fd00::2")
|
|
||||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6FragmentExtHdrIdentifier), make([]byte, 8))
|
|
||||||
|
|
||||||
_, _, _, _, ok := parseICMPv6(pkt)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv6_TruncatedExtension(t *testing.T) {
|
|
||||||
src := netip.MustParseAddr("fd00::1")
|
|
||||||
dst := netip.MustParseAddr("fd00::2")
|
|
||||||
// Extension claims 16 bytes but only 8 remain after the IP header.
|
|
||||||
hbh := []byte{uint8(header.ICMPv6ProtocolNumber), 1, 0, 0, 0, 0, 0, 0}
|
|
||||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), hbh)
|
|
||||||
|
|
||||||
_, _, _, _, ok := parseICMPv6(pkt)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv6_TruncatedICMPPayload(t *testing.T) {
|
|
||||||
src := netip.MustParseAddr("fd00::1")
|
|
||||||
dst := netip.MustParseAddr("fd00::2")
|
|
||||||
// PayloadLength claims 8 bytes of ICMPv6 but the buffer only holds 4.
|
|
||||||
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), make([]byte, 8))
|
|
||||||
pkt = pkt[:header.IPv6MinimumSize+4]
|
|
||||||
|
|
||||||
_, _, _, _, ok := parseICMPv6(pkt)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv4_RejectsShortIHL(t *testing.T) {
|
|
||||||
pkt := make([]byte, 28)
|
|
||||||
pkt[0] = 0x44 // version 4, IHL 4 (16 bytes - below minimum)
|
|
||||||
pkt[9] = uint8(header.ICMPv4ProtocolNumber)
|
|
||||||
header.IPv4(pkt).SetTotalLength(28)
|
|
||||||
|
|
||||||
_, _, _, _, ok := parseICMPv4(pkt)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv4_RejectsTotalLenOverBuffer(t *testing.T) {
|
|
||||||
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
|
|
||||||
ip := header.IPv4(pkt)
|
|
||||||
ip.Encode(&header.IPv4Fields{
|
|
||||||
TotalLength: uint16(len(pkt) + 16),
|
|
||||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
|
||||||
TTL: 64,
|
|
||||||
})
|
|
||||||
|
|
||||||
_, _, _, _, ok := parseICMPv4(pkt)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseICMPv4_RejectsFragment(t *testing.T) {
|
|
||||||
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
|
|
||||||
ip := header.IPv4(pkt)
|
|
||||||
ip.Encode(&header.IPv4Fields{
|
|
||||||
TotalLength: uint16(len(pkt)),
|
|
||||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
|
||||||
TTL: 64,
|
|
||||||
Flags: header.IPv4FlagMoreFragments,
|
|
||||||
})
|
|
||||||
|
|
||||||
_, _, _, _, ok := parseICMPv4(pkt)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
@@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
|
|||||||
}
|
}
|
||||||
|
|
||||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
|
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||||
return true
|
return true
|
||||||
@@ -58,7 +58,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
|||||||
defer func() { <-f.pingSemaphore }()
|
defer func() { <-f.pingSemaphore }()
|
||||||
|
|
||||||
if f.hasRawICMPAccess {
|
if f.hasRawICMPAccess {
|
||||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false)
|
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||||
} else {
|
} else {
|
||||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||||
}
|
}
|
||||||
@@ -72,23 +72,18 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
|||||||
|
|
||||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||||
// The caller is responsible for closing the returned connection.
|
// The caller is responsible for closing the returned connection.
|
||||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
|
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
network, listenAddr := "ip4:icmp", "0.0.0.0"
|
|
||||||
if v6 {
|
|
||||||
network, listenAddr = "ip6:ipv6-icmp", "::"
|
|
||||||
}
|
|
||||||
|
|
||||||
lc := net.ListenConfig{}
|
lc := net.ListenConfig{}
|
||||||
conn, err := lc.ListenPacket(ctx, network, listenAddr)
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
dst := &net.IPAddr{IP: dstIP.AsSlice()}
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
if closeErr := conn.Close(); closeErr != nil {
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
@@ -103,11 +98,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
|||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6.
|
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) {
|
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||||
sendTime := time.Now()
|
sendTime := time.Now()
|
||||||
|
|
||||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second)
|
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||||
return
|
return
|
||||||
@@ -118,20 +113,16 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
txBytes := f.handleEchoResponse(conn, id, v6)
|
txBytes := f.handleEchoResponse(conn, id)
|
||||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||||
|
|
||||||
proto := "ICMP"
|
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||||
if v6 {
|
epID(id), icmpType, icmpCode, rtt)
|
||||||
proto = "ICMPv6"
|
|
||||||
}
|
|
||||||
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
|
|
||||||
proto, epID(id), icmpType, icmpCode, rtt)
|
|
||||||
|
|
||||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int {
|
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||||
return 0
|
return 0
|
||||||
@@ -146,19 +137,6 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if v6 {
|
|
||||||
// Recompute checksum: the raw socket response has a checksum computed
|
|
||||||
// over the real endpoint addresses, but we inject with overlay addresses.
|
|
||||||
icmpHdr := header.ICMPv6(response[:n])
|
|
||||||
icmpHdr.SetChecksum(0)
|
|
||||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
|
||||||
Header: icmpHdr,
|
|
||||||
Src: id.LocalAddress,
|
|
||||||
Dst: id.RemoteAddress,
|
|
||||||
}))
|
|
||||||
return f.injectICMPv6Reply(id, response[:n])
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.injectICMPReply(id, response[:n])
|
return f.injectICMPReply(id, response[:n])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,23 +150,19 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
|||||||
txPackets = 1
|
txPackets = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
proto := nftypes.ICMP
|
|
||||||
if srcIp.Is6() {
|
|
||||||
proto = nftypes.ICMPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: proto,
|
Protocol: nftypes.ICMP,
|
||||||
SourceIP: srcIp,
|
// TODO: handle ipv6
|
||||||
DestIP: dstIp,
|
SourceIP: srcIp,
|
||||||
ICMPType: icmpType,
|
DestIP: dstIp,
|
||||||
ICMPCode: icmpCode,
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
RxBytes: rxBytes,
|
RxBytes: rxBytes,
|
||||||
TxBytes: txBytes,
|
TxBytes: txBytes,
|
||||||
@@ -235,164 +209,26 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
|||||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleICMPv6 handles ICMPv6 packets from the network stack.
|
|
||||||
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
|
||||||
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
|
|
||||||
|
|
||||||
flowID := uuid.New()
|
|
||||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
|
||||||
|
|
||||||
if icmpHdr.Type() == header.ICMPv6EchoRequest {
|
|
||||||
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
|
|
||||||
if !f.hasRawICMPv6Access {
|
|
||||||
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
|
||||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
|
|
||||||
if err != nil {
|
|
||||||
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback.
|
|
||||||
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
|
||||||
select {
|
|
||||||
case f.pingSemaphore <- struct{}{}:
|
|
||||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
|
||||||
rxBytes := pkt.Size()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer func() { <-f.pingSemaphore }()
|
|
||||||
|
|
||||||
if f.hasRawICMPv6Access {
|
|
||||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true)
|
|
||||||
} else {
|
|
||||||
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
default:
|
|
||||||
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
|
|
||||||
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
|
||||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
|
||||||
|
|
||||||
pingStart := time.Now()
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
|
||||||
|
|
||||||
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
|
|
||||||
epID(id), icmpType, icmpCode)
|
|
||||||
|
|
||||||
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
|
|
||||||
|
|
||||||
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
|
|
||||||
epID(id), icmpType, icmpCode, rtt)
|
|
||||||
|
|
||||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
|
|
||||||
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
|
||||||
replyICMP := make([]byte, len(icmpData))
|
|
||||||
copy(replyICMP, icmpData)
|
|
||||||
|
|
||||||
replyHdr := header.ICMPv6(replyICMP)
|
|
||||||
replyHdr.SetType(header.ICMPv6EchoReply)
|
|
||||||
replyHdr.SetChecksum(0)
|
|
||||||
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
|
|
||||||
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
|
|
||||||
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
|
||||||
Header: replyHdr,
|
|
||||||
Src: id.LocalAddress,
|
|
||||||
Dst: id.RemoteAddress,
|
|
||||||
}))
|
|
||||||
|
|
||||||
return f.injectICMPv6Reply(id, replyICMP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
|
|
||||||
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
|
||||||
ipHdr := make([]byte, header.IPv6MinimumSize)
|
|
||||||
ip := header.IPv6(ipHdr)
|
|
||||||
ip.Encode(&header.IPv6Fields{
|
|
||||||
PayloadLength: uint16(len(icmpPayload)),
|
|
||||||
TransportProtocol: header.ICMPv6ProtocolNumber,
|
|
||||||
HopLimit: 64,
|
|
||||||
SrcAddr: id.LocalAddress,
|
|
||||||
DstAddr: id.RemoteAddress,
|
|
||||||
})
|
|
||||||
|
|
||||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
|
||||||
fullPacket = append(fullPacket, ipHdr...)
|
|
||||||
fullPacket = append(fullPacket, icmpPayload...)
|
|
||||||
|
|
||||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
|
||||||
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(fullPacket)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
pingBin = "ping"
|
|
||||||
ping6Bin = "ping6"
|
|
||||||
)
|
|
||||||
|
|
||||||
// buildPingCommand creates a platform-specific ping command.
|
// buildPingCommand creates a platform-specific ping command.
|
||||||
// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
|
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||||
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
|
|
||||||
timeoutSec := int(timeout.Seconds())
|
timeoutSec := int(timeout.Seconds())
|
||||||
if timeoutSec < 1 {
|
if timeoutSec < 1 {
|
||||||
timeoutSec = 1
|
timeoutSec = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
isV6 := target.Is6()
|
|
||||||
timeoutStr := fmt.Sprintf("%d", timeoutSec)
|
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "linux", "android":
|
case "linux", "android":
|
||||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
|
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||||
case "darwin", "ios":
|
case "darwin", "ios":
|
||||||
bin := pingBin
|
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||||
if isV6 {
|
|
||||||
bin = ping6Bin
|
|
||||||
}
|
|
||||||
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
|
|
||||||
case "freebsd":
|
case "freebsd":
|
||||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
|
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||||
case "openbsd", "netbsd":
|
case "openbsd", "netbsd":
|
||||||
bin := pingBin
|
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||||
if isV6 {
|
|
||||||
bin = ping6Bin
|
|
||||||
}
|
|
||||||
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
|
|
||||||
case "windows":
|
case "windows":
|
||||||
return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||||
default:
|
default:
|
||||||
return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
|
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ package forwarder
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -32,7 +33,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -132,14 +133,15 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.TCP,
|
Protocol: nftypes.TCP,
|
||||||
|
// TODO: handle ipv6
|
||||||
SourceIP: srcIp,
|
SourceIP: srcIp,
|
||||||
DestIP: dstIp,
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||||
@@ -276,14 +276,15 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
|
|
||||||
// sendUDPEvent stores flow events for UDP connections
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.UDP,
|
Protocol: nftypes.UDP,
|
||||||
|
// TODO: handle ipv6
|
||||||
SourceIP: srcIp,
|
SourceIP: srcIp,
|
||||||
DestIP: dstIp,
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ const (
|
|||||||
ipv4HeaderMinLen = 20
|
ipv4HeaderMinLen = 20
|
||||||
ipv4ProtoOffset = 9
|
ipv4ProtoOffset = 9
|
||||||
ipv4FlagsOffset = 6
|
ipv4FlagsOffset = 6
|
||||||
|
ipv4DstOffset = 16
|
||||||
ipProtoUDP = 17
|
ipProtoUDP = 17
|
||||||
ipProtoTCP = 6
|
ipProtoTCP = 6
|
||||||
ipv4FragOffMask = 0x1fff
|
ipv4FragOffMask = 0x1fff
|
||||||
|
|||||||
@@ -4,32 +4,89 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
type localIPManager struct {
|
||||||
// atomically so reads are lock-free.
|
mu sync.RWMutex
|
||||||
type localIPSnapshot struct {
|
|
||||||
ips map[netip.Addr]struct{}
|
// fixed-size high array for upper byte of a IPv4 address
|
||||||
|
ipv4Bitmap [256]*ipv4LowBitmap
|
||||||
}
|
}
|
||||||
|
|
||||||
type localIPManager struct {
|
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||||
snapshot atomic.Pointer[localIPSnapshot]
|
type ipv4LowBitmap struct {
|
||||||
|
bitmap [8192]uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLocalIPManager() *localIPManager {
|
func newLocalIPManager() *localIPManager {
|
||||||
m := &localIPManager{}
|
return &localIPManager{}
|
||||||
m.snapshot.Store(&localIPSnapshot{
|
|
||||||
ips: make(map[netip.Addr]struct{}),
|
|
||||||
})
|
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
|
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
high := uint16(ipv4[0])
|
||||||
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
|
||||||
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
|
if !ip.Is4() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipv4 := ip.AsSlice()
|
||||||
|
|
||||||
|
high := uint16(ipv4[0])
|
||||||
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
|
|
||||||
|
if bitmap[high] == nil {
|
||||||
|
bitmap[high] = &ipv4LowBitmap{}
|
||||||
|
}
|
||||||
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
|
||||||
|
if _, exists := ipv4Set[ip]; !exists {
|
||||||
|
ipv4Set[ip] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
|
high := uint16(ip[0])
|
||||||
|
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||||
|
|
||||||
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||||
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -47,19 +104,18 @@ func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresse
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
parsed, ok := netip.AddrFromSlice(ip)
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
parsed = parsed.Unmap()
|
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
ips[parsed] = struct{}{}
|
log.Debugf("process IP failed: %v", err)
|
||||||
*addresses = append(*addresses, parsed)
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
|
||||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -67,20 +123,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ips := make(map[netip.Addr]struct{})
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
var addresses []netip.Addr
|
ipv4Set := make(map[netip.Addr]struct{})
|
||||||
|
var ipv4Addresses []netip.Addr
|
||||||
|
|
||||||
// loopback
|
// 127.0.0.0/8
|
||||||
ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
ips[netip.IPv6Loopback()] = struct{}{}
|
for i := 0; i < 8192; i++ {
|
||||||
|
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
||||||
|
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||||
|
}
|
||||||
|
|
||||||
if iface != nil {
|
if iface != nil {
|
||||||
ip := iface.Address().IP
|
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||||
ips[ip] = struct{}{}
|
return err
|
||||||
addresses = append(addresses, ip)
|
|
||||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
|
||||||
ips[v6] = struct{}{}
|
|
||||||
addresses = append(addresses, v6)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,24 +147,25 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||||
// case where an interface comes up between refreshes.
|
// case where an interface comes up between refreshes.
|
||||||
for _, intf := range interfaces {
|
for _, intf := range interfaces {
|
||||||
processInterface(intf, ips, &addresses)
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.snapshot.Store(&localIPSnapshot{ips: ips})
|
m.mu.Lock()
|
||||||
|
m.ipv4Bitmap = newIPv4Bitmap
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
log.Debugf("Local IP addresses: %v", addresses)
|
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
|
|
||||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
s := m.snapshot.Load()
|
if !ip.Is4() {
|
||||||
|
return false
|
||||||
if ip.Is4() && ip.As4()[0] == 127 {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := s.ips[ip]
|
m.mu.RLock()
|
||||||
return found
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupManager(b *testing.B) *localIPManager {
|
|
||||||
b.Helper()
|
|
||||||
m := newLocalIPManager()
|
|
||||||
mock := &IFaceMock{
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.64.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
|
||||||
IPv6: netip.MustParseAddr("fd00::1"),
|
|
||||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := m.UpdateLocalIPs(mock); err != nil {
|
|
||||||
b.Fatalf("UpdateLocalIPs: %v", err)
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
|
|
||||||
m := setupManager(b)
|
|
||||||
ip := netip.MustParseAddr("100.64.0.1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
m.IsLocalIP(ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
|
|
||||||
m := setupManager(b)
|
|
||||||
ip := netip.MustParseAddr("8.8.8.8")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
m.IsLocalIP(ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
|
|
||||||
m := setupManager(b)
|
|
||||||
ip := netip.MustParseAddr("fd00::1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
m.IsLocalIP(ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
|
|
||||||
m := setupManager(b)
|
|
||||||
ip := netip.MustParseAddr("2001:db8::1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
m.IsLocalIP(ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkIsLocalIP_loopback(b *testing.B) {
|
|
||||||
m := setupManager(b)
|
|
||||||
ip := netip.MustParseAddr("127.0.0.1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
m.IsLocalIP(ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -72,45 +72,14 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 address matches",
|
name: "IPv6 address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("100.64.0.1"),
|
IP: netip.MustParseAddr("fe80::1"),
|
||||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
|
||||||
IPv6: netip.MustParseAddr("fd00::1"),
|
|
||||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
|
||||||
},
|
|
||||||
testIP: netip.MustParseAddr("fd00::1"),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 address does not match",
|
|
||||||
setupAddr: wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.64.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
|
||||||
IPv6: netip.MustParseAddr("fd00::1"),
|
|
||||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
|
||||||
},
|
|
||||||
testIP: netip.MustParseAddr("fd00::99"),
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No aliasing between similar IPs",
|
|
||||||
setupAddr: wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("192.168.1.1"),
|
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.0.17"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "IPv6 loopback",
|
|
||||||
setupAddr: wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.64.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
|
||||||
},
|
|
||||||
testIP: netip.MustParseAddr("::1"),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -202,3 +171,90 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MapImplementation is a version using map[string]struct{}
|
||||||
|
type MapImplementation struct {
|
||||||
|
localIPs map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIPChecks(b *testing.B) {
|
||||||
|
interfaces := make([]net.IP, 16)
|
||||||
|
for i := range interfaces {
|
||||||
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup bitmap
|
||||||
|
bitmapManager := newLocalIPManager()
|
||||||
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
|
bitmapManager.setBitmapBit(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup map version
|
||||||
|
mapManager := &MapImplementation{
|
||||||
|
localIPs: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] {
|
||||||
|
mapManager.localIPs[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWGPosition(b *testing.B) {
|
||||||
|
wgIP := net.ParseIP("10.10.0.1")
|
||||||
|
|
||||||
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
|
bm := newLocalIPManager()
|
||||||
|
bm.setBitmapBit(wgIP)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
|
bm := newLocalIPManager()
|
||||||
|
// Fill with other IPs first
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
}
|
||||||
|
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||||
)
|
)
|
||||||
@@ -23,33 +25,10 @@ const (
|
|||||||
destinationPortOffset = 2
|
destinationPortOffset = 2
|
||||||
|
|
||||||
// IP address offsets in IPv4 header
|
// IP address offsets in IPv4 header
|
||||||
ipv4SrcOffset = 12
|
sourceIPOffset = 12
|
||||||
ipv4DstOffset = 16
|
destinationIPOffset = 16
|
||||||
|
|
||||||
// IP address offsets in IPv6 header
|
|
||||||
ipv6SrcOffset = 8
|
|
||||||
ipv6DstOffset = 24
|
|
||||||
|
|
||||||
// IPv6 fixed header length
|
|
||||||
ipv6HeaderLen = 40
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ipHeaderLen returns the IP header length based on the decoded layer type.
|
|
||||||
func ipHeaderLen(d *decoder) (int, error) {
|
|
||||||
switch d.decoded[0] {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
n := int(d.ip4.IHL) * 4
|
|
||||||
if n < 20 {
|
|
||||||
return 0, errInvalidIPHeaderLength
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
return ipv6HeaderLen, nil
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ipv4Checksum calculates IPv4 header checksum.
|
// ipv4Checksum calculates IPv4 header checksum.
|
||||||
func ipv4Checksum(header []byte) uint16 {
|
func ipv4Checksum(header []byte) uint16 {
|
||||||
if len(header) < 20 {
|
if len(header) < 20 {
|
||||||
@@ -255,13 +234,14 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, dstIP := extractPacketIPs(packetData, d)
|
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||||
|
|
||||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
|
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -276,13 +256,14 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP, _ := extractPacketIPs(packetData, d)
|
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||||
|
|
||||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
|
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -291,96 +272,38 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
|
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||||
func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
|
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||||
switch d.decoded[0] {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
|
|
||||||
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
|
|
||||||
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
|
|
||||||
}
|
|
||||||
return src, dst
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
|
|
||||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
|
|
||||||
hdrLen, err := ipHeaderLen(d)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch d.decoded[0] {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
|
||||||
if !newIP.Is4() {
|
if !newIP.Is4() {
|
||||||
return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
|
return ErrIPv4Only
|
||||||
}
|
|
||||||
|
|
||||||
offset := ipv4DstOffset
|
|
||||||
if isSource {
|
|
||||||
offset = ipv4SrcOffset
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var oldIP [4]byte
|
var oldIP [4]byte
|
||||||
copy(oldIP[:], packetData[offset:offset+4])
|
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||||
newIPBytes := newIP.As4()
|
newIPBytes := newIP.As4()
|
||||||
copy(packetData[offset:offset+4], newIPBytes[:])
|
|
||||||
|
|
||||||
// Recalculate IPv4 header checksum
|
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||||
|
|
||||||
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||||
|
return errInvalidIPHeaderLength
|
||||||
|
}
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
|
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||||
|
|
||||||
// Update transport checksums incrementally
|
|
||||||
if len(d.decoded) > 1 {
|
if len(d.decoded) > 1 {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.updateICMPChecksum(packetData, hdrLen)
|
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
|
||||||
if !newIP.Is6() {
|
|
||||||
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := ipv6DstOffset
|
|
||||||
if isSource {
|
|
||||||
offset = ipv6SrcOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
var oldIP [16]byte
|
|
||||||
copy(oldIP[:], packetData[offset:offset+16])
|
|
||||||
newIPBytes := newIP.As16()
|
|
||||||
copy(packetData[offset:offset+16], newIPBytes[:])
|
|
||||||
|
|
||||||
// IPv6 has no header checksum, only update transport checksums
|
|
||||||
if len(d.decoded) > 1 {
|
|
||||||
switch d.decoded[1] {
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
|
||||||
case layers.LayerTypeICMPv6:
|
|
||||||
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
|
|
||||||
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,20 +351,6 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
|||||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
|
|
||||||
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
|
|
||||||
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
|
||||||
icmpStart := ipHeaderLen
|
|
||||||
if len(packetData) < icmpStart+4 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
checksumOffset := icmpStart + 2
|
|
||||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
|
||||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
|
||||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
|
||||||
}
|
|
||||||
|
|
||||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||||
sum := uint32(^oldChecksum)
|
sum := uint32(^oldChecksum)
|
||||||
@@ -494,14 +403,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addPortRedirection adds a port redirection rule.
|
// addPortRedirection adds a port redirection rule.
|
||||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
|
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||||
m.portDNATMutex.Lock()
|
m.portDNATMutex.Lock()
|
||||||
defer m.portDNATMutex.Unlock()
|
defer m.portDNATMutex.Unlock()
|
||||||
|
|
||||||
rule := portDNATRule{
|
rule := portDNATRule{
|
||||||
protocol: protocol,
|
protocol: protocol,
|
||||||
origPort: originalPort,
|
origPort: sourcePort,
|
||||||
targetPort: translatedPort,
|
targetPort: targetPort,
|
||||||
targetIP: targetIP,
|
targetIP: targetIP,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -513,7 +422,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
|||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
var layerType gopacket.LayerType
|
var layerType gopacket.LayerType
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case firewall.ProtocolTCP:
|
case firewall.ProtocolTCP:
|
||||||
@@ -524,16 +433,16 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
|||||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.addPortRedirection(localAddr, layerType, originalPort, translatedPort)
|
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// removePortRedirection removes a port redirection rule.
|
// removePortRedirection removes a port redirection rule.
|
||||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
|
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||||
m.portDNATMutex.Lock()
|
m.portDNATMutex.Lock()
|
||||||
defer m.portDNATMutex.Unlock()
|
defer m.portDNATMutex.Unlock()
|
||||||
|
|
||||||
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
|
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
|
||||||
return rule.protocol == protocol && rule.origPort == originalPort && rule.targetPort == translatedPort && rule.targetIP.Compare(targetIP) == 0
|
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(m.portDNATRules) == 0 {
|
if len(m.portDNATRules) == 0 {
|
||||||
@@ -544,7 +453,7 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
var layerType gopacket.LayerType
|
var layerType gopacket.LayerType
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case firewall.ProtocolTCP:
|
case firewall.ProtocolTCP:
|
||||||
@@ -555,23 +464,23 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
|
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT delegates to the native firewall if available.
|
// AddOutputDNAT delegates to the native firewall if available.
|
||||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||||
@@ -623,12 +532,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
|||||||
|
|
||||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||||
hdrLen, err := ipHeaderLen(d)
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
if err != nil {
|
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||||
return err
|
return errInvalidIPHeaderLength
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpStart := hdrLen
|
tcpStart := ipHeaderLen
|
||||||
if len(packetData) < tcpStart+4 {
|
if len(packetData) < tcpStart+4 {
|
||||||
return fmt.Errorf("packet too short for TCP header")
|
return fmt.Errorf("packet too short for TCP header")
|
||||||
}
|
}
|
||||||
@@ -654,12 +563,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
|
|||||||
|
|
||||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||||
hdrLen, err := ipHeaderLen(d)
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
if err != nil {
|
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||||
return err
|
return errInvalidIPHeaderLength
|
||||||
}
|
}
|
||||||
|
|
||||||
udpStart := hdrLen
|
udpStart := ipHeaderLen
|
||||||
if len(packetData) < udpStart+8 {
|
if len(packetData) < udpStart+8 {
|
||||||
return fmt.Errorf("packet too short for UDP header")
|
return fmt.Errorf("packet too short for UDP header")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -342,17 +342,12 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
|||||||
|
|
||||||
// Parse the packet fresh each time to get a clean decoder
|
// Parse the packet fresh each time to get a clean decoder
|
||||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser4.IgnoreUnsupported = true
|
d.parser.IgnoreUnsupported = true
|
||||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||||
layers.LayerTypeIPv6,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser6.IgnoreUnsupported = true
|
|
||||||
err = d.decodePacket(testPacket)
|
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
|
|
||||||
manager.translateOutboundDNAT(testPacket, d)
|
manager.translateOutboundDNAT(testPacket, d)
|
||||||
@@ -376,17 +371,12 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
|
|||||||
b.Run("decoder_extraction", func(b *testing.B) {
|
b.Run("decoder_extraction", func(b *testing.B) {
|
||||||
// Create decoder once for comparison
|
// Create decoder once for comparison
|
||||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser4.IgnoreUnsupported = true
|
d.parser.IgnoreUnsupported = true
|
||||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||||
layers.LayerTypeIPv6,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser6.IgnoreUnsupported = true
|
|
||||||
err := d.decodePacket(packet)
|
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|||||||
@@ -86,18 +86,13 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser4.IgnoreUnsupported = true
|
d.parser.IgnoreUnsupported = true
|
||||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv6,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser6.IgnoreUnsupported = true
|
|
||||||
|
|
||||||
err := d.decodePacket(packetData)
|
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -114,13 +112,10 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||||
ipLayer, err := p.buildIPLayer()
|
ip := p.buildIPLayer()
|
||||||
if err != nil {
|
pktLayers := []gopacket.SerializableLayer{ip}
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pktLayers := []gopacket.SerializableLayer{ipLayer}
|
|
||||||
|
|
||||||
transportLayer, err := p.buildTransportLayer(ipLayer)
|
transportLayer, err := p.buildTransportLayer(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -134,43 +129,30 @@ func (p *PacketBuilder) Build() ([]byte, error) {
|
|||||||
return serializePacket(pktLayers)
|
return serializePacket(pktLayers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||||
if p.SrcIP.Is4() != p.DstIP.Is4() {
|
|
||||||
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
|
|
||||||
}
|
|
||||||
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
|
|
||||||
if p.SrcIP.Is6() {
|
|
||||||
return &layers.IPv6{
|
|
||||||
Version: 6,
|
|
||||||
HopLimit: 64,
|
|
||||||
NextHeader: proto,
|
|
||||||
SrcIP: p.SrcIP.AsSlice(),
|
|
||||||
DstIP: p.DstIP.AsSlice(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &layers.IPv4{
|
return &layers.IPv4{
|
||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: proto,
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
SrcIP: p.SrcIP.AsSlice(),
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
DstIP: p.DstIP.AsSlice(),
|
DstIP: p.DstIP.AsSlice(),
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
switch p.Protocol {
|
switch p.Protocol {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return p.buildTCPLayer(ipLayer)
|
return p.buildTCPLayer(ip)
|
||||||
case "udp":
|
case "udp":
|
||||||
return p.buildUDPLayer(ipLayer)
|
return p.buildUDPLayer(ip)
|
||||||
case "icmp":
|
case "icmp":
|
||||||
return p.buildICMPLayer(ipLayer)
|
return p.buildICMPLayer()
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
tcp := &layers.TCP{
|
tcp := &layers.TCP{
|
||||||
SrcPort: layers.TCPPort(p.SrcPort),
|
SrcPort: layers.TCPPort(p.SrcPort),
|
||||||
DstPort: layers.TCPPort(p.DstPort),
|
DstPort: layers.TCPPort(p.DstPort),
|
||||||
@@ -182,44 +164,24 @@ func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gop
|
|||||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||||
URG: p.TCPState != nil && p.TCPState.URG,
|
URG: p.TCPState != nil && p.TCPState.URG,
|
||||||
}
|
}
|
||||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
|
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return []gopacket.SerializableLayer{tcp}, nil
|
return []gopacket.SerializableLayer{tcp}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
udp := &layers.UDP{
|
udp := &layers.UDP{
|
||||||
SrcPort: layers.UDPPort(p.SrcPort),
|
SrcPort: layers.UDPPort(p.SrcPort),
|
||||||
DstPort: layers.UDPPort(p.DstPort),
|
DstPort: layers.UDPPort(p.DstPort),
|
||||||
}
|
}
|
||||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
|
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return []gopacket.SerializableLayer{udp}, nil
|
return []gopacket.SerializableLayer{udp}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||||
if p.SrcIP.Is6() || p.DstIP.Is6() {
|
|
||||||
icmp := &layers.ICMPv6{
|
|
||||||
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
|
|
||||||
}
|
|
||||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
|
||||||
_ = icmp.SetNetworkLayerForChecksum(nl)
|
|
||||||
}
|
|
||||||
if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply {
|
|
||||||
echo := &layers.ICMPv6Echo{
|
|
||||||
Identifier: 1,
|
|
||||||
SeqNumber: 1,
|
|
||||||
}
|
|
||||||
return []gopacket.SerializableLayer{icmp, echo}, nil
|
|
||||||
}
|
|
||||||
return []gopacket.SerializableLayer{icmp}, nil
|
|
||||||
}
|
|
||||||
icmp := &layers.ICMPv4{
|
icmp := &layers.ICMPv4{
|
||||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||||
}
|
}
|
||||||
@@ -242,17 +204,14 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
|||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
|
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case fw.ProtocolTCP:
|
case fw.ProtocolTCP:
|
||||||
return layers.IPProtocolTCP
|
return int(layers.IPProtocolTCP)
|
||||||
case fw.ProtocolUDP:
|
case fw.ProtocolUDP:
|
||||||
return layers.IPProtocolUDP
|
return int(layers.IPProtocolUDP)
|
||||||
case fw.ProtocolICMP:
|
case fw.ProtocolICMP:
|
||||||
if isV6 {
|
return int(layers.IPProtocolICMPv4)
|
||||||
return layers.IPProtocolICMPv6
|
|
||||||
}
|
|
||||||
return layers.IPProtocolICMPv4
|
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -275,7 +234,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
trace := &PacketTrace{Direction: direction}
|
trace := &PacketTrace{Direction: direction}
|
||||||
|
|
||||||
// Initial packet decoding
|
// Initial packet decoding
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
@@ -297,8 +256,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
trace.Protocol = "ICMP"
|
trace.Protocol = "ICMP"
|
||||||
case layers.LayerTypeICMPv6:
|
|
||||||
trace.Protocol = "ICMPv6"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||||
@@ -362,13 +319,6 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
|||||||
flags&conntrack.TCPFin != 0)
|
flags&conntrack.TCPFin != 0)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||||
case layers.LayerTypeICMPv6:
|
|
||||||
var id, seq uint16
|
|
||||||
if len(d.icmp6.Payload) >= 4 {
|
|
||||||
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
|
||||||
seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3])
|
|
||||||
}
|
|
||||||
msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq)
|
|
||||||
}
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
@@ -445,7 +395,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
|
|||||||
trace.AddResult(StageRouteACL, msg, allowed)
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
if allowed && m.forwarder.Load() != nil {
|
if allowed && m.forwarder.Load() != nil {
|
||||||
m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true)
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||||
@@ -465,7 +415,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
|||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
@@ -484,7 +434,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
|||||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||||
if portDNATApplied {
|
if portDNATApplied {
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -494,7 +444,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
|
|||||||
|
|
||||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||||
if nat1to1Applied {
|
if nat1to1Applied {
|
||||||
if err := d.decodePacket(packetData); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -559,7 +509,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP, _ := extractPacketIPs(packetData, d)
|
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||||
|
|
||||||
translated := m.translateInboundReverse(packetData, d)
|
translated := m.translateInboundReverse(packetData, d)
|
||||||
if translated {
|
if translated {
|
||||||
@@ -589,7 +539,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, dstIP := extractPacketIPs(packetData, d)
|
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||||
|
|
||||||
translated := m.translateOutboundDNAT(packetData, d)
|
translated := m.translateOutboundDNAT(packetData, d)
|
||||||
if translated {
|
if translated {
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||||
}
|
}
|
||||||
addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
|
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package device
|
|||||||
|
|
||||||
// TunAdapter is an interface for create tun device from external service
|
// TunAdapter is an interface for create tun device from external service
|
||||||
type TunAdapter interface {
|
type TunAdapter interface {
|
||||||
ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||||
UpdateAddr(address string) error
|
UpdateAddr(address string) error
|
||||||
ProtectSocket(fd int32) bool
|
ProtectSocket(fd int32) bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
searchDomainsToString = ""
|
searchDomainsToString = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -131,32 +131,23 @@ func (t *TunDevice) Device() *device.Device {
|
|||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
func (t *TunDevice) assignAddr() error {
|
func (t *TunDevice) assignAddr() error {
|
||||||
if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil {
|
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||||
return fmt.Errorf("add v4 address: %s: %w", string(out), err)
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assign a dummy link-local so macOS enables IPv6 on the tun device.
|
// dummy ipv6 so routing works
|
||||||
// When a real overlay v6 is present, use that instead.
|
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
|
||||||
v6Addr := "fe80::/64"
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
if t.address.HasIPv6() {
|
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||||
v6Addr = t.address.IPv6String()
|
|
||||||
}
|
|
||||||
if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil {
|
|
||||||
log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err)
|
|
||||||
t.address.ClearIPv6()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil {
|
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
|
||||||
return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err)
|
if out, err := routeCmd.CombinedOutput(); err != nil {
|
||||||
|
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.address.HasIPv6() {
|
|
||||||
if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil {
|
|
||||||
log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err)
|
|
||||||
t.address.ClearIPv6()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -151,11 +151,8 @@ func (t *TunDevice) MTU() uint16 {
|
|||||||
return t.mtu
|
return t.mtu
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAddr updates the device address. On iOS the tunnel is managed by the
|
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||||
// NetworkExtension, so we only store the new value. The extension picks up the
|
// todo implement
|
||||||
// change on the next tunnel reconfiguration.
|
|
||||||
func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error {
|
|
||||||
t.address = addr
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
|||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface
|
// assignAddr Adds IP address to the tunnel interface
|
||||||
func (t *TunKernelDevice) assignAddr() error {
|
func (t *TunKernelDevice) assignAddr() error {
|
||||||
return t.link.assignAddr(&t.address)
|
return t.link.assignAddr(t.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package device
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
@@ -64,12 +63,8 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("last ip: %w", err)
|
return nil, fmt.Errorf("last ip: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
addresses := []netip.Addr{t.address.IP}
|
log.Debugf("netstack using address: %s", t.address.IP)
|
||||||
if t.address.HasIPv6() {
|
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
|
||||||
addresses = append(addresses, t.address.IPv6)
|
|
||||||
}
|
|
||||||
log.Debugf("netstack using addresses: %v", addresses)
|
|
||||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu))
|
|
||||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||||
tunIface, net, err := t.nsTun.Create()
|
tunIface, net, err := t.nsTun.Create()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunDevice struct {
|
type USPDevice struct {
|
||||||
name string
|
name string
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
@@ -30,10 +30,10 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
|
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
|
||||||
log.Infof("using userspace bind mode")
|
log.Infof("using userspace bind mode")
|
||||||
|
|
||||||
return &TunDevice{
|
return &USPDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -43,7 +43,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) Create() (WGConfigurer, error) {
|
func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||||
log.Info("create tun interface")
|
log.Info("create tun interface")
|
||||||
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
|
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -75,7 +75,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return t.configurer, nil
|
return t.configurer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
if t.device == nil {
|
if t.device == nil {
|
||||||
return nil, fmt.Errorf("device is not ready yet")
|
return nil, fmt.Errorf("device is not ready yet")
|
||||||
}
|
}
|
||||||
@@ -95,12 +95,12 @@ func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) Close() error {
|
func (t *USPDevice) Close() error {
|
||||||
if t.configurer != nil {
|
if t.configurer != nil {
|
||||||
t.configurer.Close()
|
t.configurer.Close()
|
||||||
}
|
}
|
||||||
@@ -115,39 +115,39 @@ func (t *TunDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) MTU() uint16 {
|
func (t *USPDevice) MTU() uint16 {
|
||||||
return t.mtu
|
return t.mtu
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) DeviceName() string {
|
func (t *USPDevice) DeviceName() string {
|
||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
// Device returns the wireguard device
|
// Device returns the wireguard device
|
||||||
func (t *TunDevice) Device() *device.Device {
|
func (t *USPDevice) Device() *device.Device {
|
||||||
return t.device
|
return t.device
|
||||||
}
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface
|
// assignAddr Adds IP address to the tunnel interface
|
||||||
func (t *TunDevice) assignAddr() error {
|
func (t *USPDevice) assignAddr() error {
|
||||||
link := newWGLink(t.name)
|
link := newWGLink(t.name)
|
||||||
|
|
||||||
return link.assignAddr(&t.address)
|
return link.assignAddr(t.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) GetNet() *netstack.Net {
|
func (t *USPDevice) GetNet() *netstack.Net {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetICEBind returns the ICEBind instance
|
// GetICEBind returns the ICEBind instance
|
||||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
func (t *USPDevice) GetICEBind() EndpointManager {
|
||||||
return t.iceBind
|
return t.iceBind
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,21 +87,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
err = nbiface.Set()
|
err = nbiface.Set()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
return nil, fmt.Errorf("set IPv4 interface MTU: %s", err)
|
return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
if t.address.HasIPv6() {
|
|
||||||
nbiface6, err := luid.IPInterface(windows.AF_INET6)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to get IPv6 interface for MTU, continuing v4-only: %v", err)
|
|
||||||
t.address.ClearIPv6()
|
|
||||||
} else {
|
|
||||||
nbiface6.NLMTU = uint32(t.mtu)
|
|
||||||
if err := nbiface6.Set(); err != nil {
|
|
||||||
log.Warnf("failed to set IPv6 interface MTU, continuing v4-only: %v", err)
|
|
||||||
t.address.ClearIPv6()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
err = t.assignAddr()
|
err = t.assignAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -192,21 +178,8 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
|||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
func (t *TunDevice) assignAddr() error {
|
func (t *TunDevice) assignAddr() error {
|
||||||
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
|
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
|
||||||
|
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
|
||||||
v4Prefix := t.address.Prefix()
|
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
|
||||||
if t.address.HasIPv6() {
|
|
||||||
v6Prefix := t.address.IPv6Prefix()
|
|
||||||
log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name)
|
|
||||||
if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil {
|
|
||||||
log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err)
|
|
||||||
t.address.ClearIPv6()
|
|
||||||
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("adding address %s to interface: %s", v4Prefix, t.name)
|
|
||||||
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) GetNet() *netstack.Net {
|
func (t *TunDevice) GetNet() *netstack.Net {
|
||||||
|
|||||||
8
client/iface/device/kernel_module.go
Normal file
8
client/iface/device/kernel_module.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build (!linux && !freebsd) || android
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
|
||||||
|
func WireGuardModuleIsLoaded() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
18
client/iface/device/kernel_module_freebsd.go
Normal file
18
client/iface/device/kernel_module_freebsd.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
// WireGuardModuleIsLoaded check if kernel support wireguard
|
||||||
|
func WireGuardModuleIsLoaded() bool {
|
||||||
|
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
|
||||||
|
// we are currently do not use it, since it is required to add wireguard kernel support to
|
||||||
|
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
|
||||||
|
// - https://github.com/mdlayher/socket
|
||||||
|
// TODO: implement kernel space
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
|
||||||
|
func ModuleTunIsLoaded() bool {
|
||||||
|
// Assume tun supported by freebsd kernel by default
|
||||||
|
// TODO: implement check for module loaded in kernel or build-it
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available.
|
|
||||||
func WireGuardModuleIsLoaded() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModuleTunIsLoaded reports whether the tun device is available.
|
|
||||||
func ModuleTunIsLoaded() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
@@ -2,7 +2,6 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -58,32 +57,32 @@ func (l *wgLink) up() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||||
link, err := freebsd.LinkByName(l.name)
|
link, err := freebsd.LinkByName(l.name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("link by name: %w", err)
|
return fmt.Errorf("link by name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ip := address.IP.String()
|
||||||
|
|
||||||
|
// Convert prefix length to hex netmask
|
||||||
prefixLen := address.Network.Bits()
|
prefixLen := address.Network.Bits()
|
||||||
|
if !address.IP.Is4() {
|
||||||
|
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||||
|
}
|
||||||
|
|
||||||
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||||
mask := fmt.Sprintf("0x%08x", maskBits)
|
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||||
|
|
||||||
log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name)
|
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||||
|
|
||||||
if err := link.AssignAddr(address.IP.String(), mask); err != nil {
|
err = link.AssignAddr(ip, mask)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("assign addr: %w", err)
|
return fmt.Errorf("assign addr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if address.HasIPv6() {
|
err = link.Up()
|
||||||
log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name)
|
if err != nil {
|
||||||
cmd := exec.Command("ifconfig", l.name, "inet6", address.IPv6String())
|
|
||||||
if out, err := cmd.CombinedOutput(); err != nil {
|
|
||||||
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err)
|
|
||||||
address.ClearIPv6()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := link.Up(); err != nil {
|
|
||||||
return fmt.Errorf("up: %w", err)
|
return fmt.Errorf("up: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -94,7 +92,7 @@ func (l *wgLink) up() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||||
//delete existing addresses
|
//delete existing addresses
|
||||||
list, err := netlink.AddrList(l, 0)
|
list, err := netlink.AddrList(l, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -112,16 +110,20 @@ func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
name := l.attrs.Name
|
name := l.attrs.Name
|
||||||
|
addrStr := address.String()
|
||||||
|
|
||||||
if err := l.addAddr(name, address.Prefix()); err != nil {
|
log.Debugf("adding address %s to interface: %s", addrStr, name)
|
||||||
return err
|
|
||||||
|
addr, err := netlink.ParseAddr(addrStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse addr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if address.HasIPv6() {
|
err = netlink.AddrAdd(l, addr)
|
||||||
if err := l.addAddr(name, address.IPv6Prefix()); err != nil {
|
if os.IsExist(err) {
|
||||||
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err)
|
log.Infof("interface %s already has the address: %s", name, addrStr)
|
||||||
address.ClearIPv6()
|
} else if err != nil {
|
||||||
}
|
return fmt.Errorf("add addr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// On linux, the link must be brought up
|
// On linux, the link must be brought up
|
||||||
@@ -131,22 +133,3 @@ func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error {
|
|
||||||
log.Debugf("adding address %s to interface: %s", prefix, ifaceName)
|
|
||||||
|
|
||||||
addr := &netlink.Addr{
|
|
||||||
IPNet: &net.IPNet{
|
|
||||||
IP: prefix.Addr().AsSlice(),
|
|
||||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.AddrAdd(l, addr); os.IsExist(err) {
|
|
||||||
log.Infof("interface %s already has the address: %s", ifaceName, prefix)
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("add addr %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ type wgProxyFactory interface {
|
|||||||
|
|
||||||
type WGIFaceOpts struct {
|
type WGIFaceOpts struct {
|
||||||
IFaceName string
|
IFaceName string
|
||||||
Address wgaddr.Address
|
Address string
|
||||||
WGPort int
|
WGPort int
|
||||||
WGPrivKey string
|
WGPrivKey string
|
||||||
MTU uint16
|
MTU uint16
|
||||||
@@ -141,11 +141,16 @@ func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAddr updates address of the interface
|
// UpdateAddr updates address of the interface
|
||||||
func (w *WGIface) UpdateAddr(newAddr wgaddr.Address) error {
|
func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
return w.tun.UpdateAddr(newAddr)
|
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.tun.UpdateAddr(addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||||
|
|||||||
@@ -4,17 +4,23 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
@@ -22,7 +28,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewTunDevice(opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
|
|||||||
35
client/iface/iface_new_darwin.go
Normal file
35
client/iface/iface_new_darwin.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
|
||||||
|
var tun WGTunDevice
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
|
} else {
|
||||||
|
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{
|
||||||
|
userspaceBind: true,
|
||||||
|
tun: tun,
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
|
}
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
41
client/iface/iface_new_freebsd.go
Normal file
41
client/iface/iface_new_freebsd.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build freebsd
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.ModuleTunIsLoaded() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||||
|
}
|
||||||
@@ -5,15 +5,21 @@ package iface
|
|||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,15 +4,21 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
relayBind := bind.NewRelayBindJS()
|
relayBind := bind.NewRelayBindJS()
|
||||||
|
|
||||||
wgIface := &WGIface{
|
wgIface := &WGIface{
|
||||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,40 +3,44 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"fmt"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{}
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
return &WGIface{
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
wgIFace.userspaceBind = true
|
||||||
userspaceBind: true,
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
return wgIFace, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if device.WireGuardModuleIsLoaded() {
|
if device.WireGuardModuleIsLoaded() {
|
||||||
return &WGIface{
|
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
|
||||||
tun: device.NewKernelDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet),
|
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
|
||||||
wgProxyFactory: wgproxy.NewKernelFactory(opts.WGPort, opts.MTU),
|
return wgIFace, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if device.ModuleTunIsLoaded() {
|
if device.ModuleTunIsLoaded() {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
return &WGIface{
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
|
wgIFace.userspaceBind = true
|
||||||
userspaceBind: true,
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
return wgIFace, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("tun module not available")
|
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,28 +1,33 @@
|
|||||||
//go:build !linux && !ios && !android && !js
|
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
|
||||||
var tun WGTunDevice
|
var tun WGTunDevice
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
tun = device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
} else {
|
} else {
|
||||||
tun = device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: tun,
|
tun: tun,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}, nil
|
}
|
||||||
|
return wgIFace, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,7 +48,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
|
|
||||||
opts := WGIFaceOpts{
|
opts := WGIFaceOpts{
|
||||||
IFaceName: ifaceName,
|
IFaceName: ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(addr),
|
Address: addr,
|
||||||
WGPort: wgPort,
|
WGPort: wgPort,
|
||||||
WGPrivKey: key,
|
WGPrivKey: key,
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -85,7 +84,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
|
|
||||||
//update WireGuard address
|
//update WireGuard address
|
||||||
addr = "100.64.0.2/8"
|
addr = "100.64.0.2/8"
|
||||||
err = iface.UpdateAddr(wgaddr.MustParseWGAddress(addr))
|
err = iface.UpdateAddr(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -131,7 +130,7 @@ func Test_CreateInterface(t *testing.T) {
|
|||||||
}
|
}
|
||||||
opts := WGIFaceOpts{
|
opts := WGIFaceOpts{
|
||||||
IFaceName: ifaceName,
|
IFaceName: ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
Address: wgIP,
|
||||||
WGPort: 33100,
|
WGPort: 33100,
|
||||||
WGPrivKey: key,
|
WGPrivKey: key,
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -175,7 +174,7 @@ func Test_Close(t *testing.T) {
|
|||||||
|
|
||||||
opts := WGIFaceOpts{
|
opts := WGIFaceOpts{
|
||||||
IFaceName: ifaceName,
|
IFaceName: ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
Address: wgIP,
|
||||||
WGPort: wgPort,
|
WGPort: wgPort,
|
||||||
WGPrivKey: key,
|
WGPrivKey: key,
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -220,7 +219,7 @@ func TestRecreation(t *testing.T) {
|
|||||||
|
|
||||||
opts := WGIFaceOpts{
|
opts := WGIFaceOpts{
|
||||||
IFaceName: ifaceName,
|
IFaceName: ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
Address: wgIP,
|
||||||
WGPort: wgPort,
|
WGPort: wgPort,
|
||||||
WGPrivKey: key,
|
WGPrivKey: key,
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -292,7 +291,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
}
|
}
|
||||||
opts := WGIFaceOpts{
|
opts := WGIFaceOpts{
|
||||||
IFaceName: ifaceName,
|
IFaceName: ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
Address: wgIP,
|
||||||
WGPort: wgPort,
|
WGPort: wgPort,
|
||||||
WGPrivKey: key,
|
WGPrivKey: key,
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -348,7 +347,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
|
|
||||||
opts := WGIFaceOpts{
|
opts := WGIFaceOpts{
|
||||||
IFaceName: ifaceName,
|
IFaceName: ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
Address: wgIP,
|
||||||
WGPort: 33100,
|
WGPort: 33100,
|
||||||
WGPrivKey: key,
|
WGPrivKey: key,
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -418,7 +417,7 @@ func Test_RemovePeer(t *testing.T) {
|
|||||||
|
|
||||||
opts := WGIFaceOpts{
|
opts := WGIFaceOpts{
|
||||||
IFaceName: ifaceName,
|
IFaceName: ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
Address: wgIP,
|
||||||
WGPort: 33100,
|
WGPort: 33100,
|
||||||
WGPrivKey: key,
|
WGPrivKey: key,
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -483,7 +482,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
|
|
||||||
optsPeer1 := WGIFaceOpts{
|
optsPeer1 := WGIFaceOpts{
|
||||||
IFaceName: peer1ifaceName,
|
IFaceName: peer1ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(peer1wgIP.String()),
|
Address: peer1wgIP.String(),
|
||||||
WGPort: peer1wgPort,
|
WGPort: peer1wgPort,
|
||||||
WGPrivKey: peer1Key.String(),
|
WGPrivKey: peer1Key.String(),
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
@@ -523,7 +522,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
|
|
||||||
optsPeer2 := WGIFaceOpts{
|
optsPeer2 := WGIFaceOpts{
|
||||||
IFaceName: peer2ifaceName,
|
IFaceName: peer2ifaceName,
|
||||||
Address: wgaddr.MustParseWGAddress(peer2wgIP.String()),
|
Address: peer2wgIP.String(),
|
||||||
WGPort: peer2wgPort,
|
WGPort: peer2wgPort,
|
||||||
WGPrivKey: peer2Key.String(),
|
WGPrivKey: peer2Key.String(),
|
||||||
MTU: DefaultMTU,
|
MTU: DefaultMTU,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||||
|
|
||||||
type NetStackTun struct { //nolint:revive
|
type NetStackTun struct { //nolint:revive
|
||||||
addresses []netip.Addr
|
address netip.Addr
|
||||||
dnsAddress netip.Addr
|
dnsAddress netip.Addr
|
||||||
mtu int
|
mtu int
|
||||||
listenAddress string
|
listenAddress string
|
||||||
@@ -22,9 +22,9 @@ type NetStackTun struct { //nolint:revive
|
|||||||
tundev tun.Device
|
tundev tun.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||||
return &NetStackTun{
|
return &NetStackTun{
|
||||||
addresses: addresses,
|
address: address,
|
||||||
dnsAddress: dnsAddress,
|
dnsAddress: dnsAddress,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
@@ -33,7 +33,7 @@ func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress net
|
|||||||
|
|
||||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||||
t.addresses,
|
[]netip.Addr{t.address},
|
||||||
[]netip.Addr{t.dnsAddress},
|
[]netip.Addr{t.dnsAddress},
|
||||||
t.mtu)
|
t.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,18 +3,12 @@ package wgaddr
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type Address struct {
|
type Address struct {
|
||||||
IP netip.Addr
|
IP netip.Addr
|
||||||
Network netip.Prefix
|
Network netip.Prefix
|
||||||
|
|
||||||
// IPv6 overlay address, if assigned.
|
|
||||||
IPv6 netip.Addr
|
|
||||||
IPv6Net netip.Prefix
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
@@ -29,60 +23,6 @@ func ParseWGAddress(address string) (Address, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasIPv6 reports whether a v6 overlay address is assigned.
|
|
||||||
func (addr Address) HasIPv6() bool {
|
|
||||||
return addr.IPv6.IsValid()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (addr Address) String() string {
|
func (addr Address) String() string {
|
||||||
return addr.Prefix().String()
|
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6String returns the v6 address in CIDR notation, or empty string if none.
|
|
||||||
func (addr Address) IPv6String() string {
|
|
||||||
if !addr.HasIPv6() {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return addr.IPv6Prefix().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prefix returns the v4 host address with its network prefix length (e.g. 100.64.0.1/16).
|
|
||||||
func (addr Address) Prefix() netip.Prefix {
|
|
||||||
return netip.PrefixFrom(addr.IP, addr.Network.Bits())
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6Prefix returns the v6 host address with its network prefix length, or a zero prefix if none.
|
|
||||||
func (addr Address) IPv6Prefix() netip.Prefix {
|
|
||||||
if !addr.HasIPv6() {
|
|
||||||
return netip.Prefix{}
|
|
||||||
}
|
|
||||||
return netip.PrefixFrom(addr.IPv6, addr.IPv6Net.Bits())
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetIPv6FromCompact decodes a compact prefix (5 or 17 bytes) and sets the IPv6 fields.
|
|
||||||
// Returns an error if the bytes are invalid. A nil or empty input is a no-op.
|
|
||||||
//
|
|
||||||
//nolint:recvcheck
|
|
||||||
func (addr *Address) SetIPv6FromCompact(raw []byte) error {
|
|
||||||
if len(raw) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
prefix, err := netiputil.DecodePrefix(raw)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("decode v6 overlay address: %w", err)
|
|
||||||
}
|
|
||||||
if !prefix.Addr().Is6() {
|
|
||||||
return fmt.Errorf("expected IPv6 address, got %s", prefix.Addr())
|
|
||||||
}
|
|
||||||
addr.IPv6 = prefix.Addr()
|
|
||||||
addr.IPv6Net = prefix.Masked()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearIPv6 removes the IPv6 overlay address, leaving only v4.
|
|
||||||
//
|
|
||||||
//nolint:recvcheck
|
|
||||||
func (addr *Address) ClearIPv6() {
|
|
||||||
addr.IPv6 = netip.Addr{}
|
|
||||||
addr.IPv6Net = netip.Prefix{}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
package wgaddr
|
|
||||||
|
|
||||||
// MustParseWGAddress parses and returns a WG Address, panicking on error.
|
|
||||||
func MustParseWGAddress(address string) Address {
|
|
||||||
a, err := ParseWGAddress(address)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -196,25 +196,18 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fakeAddress returns a fake address that is used as an identifier for the peer.
|
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||||
// The fake address is in the format of 127.1.x.x where x.x is derived from the
|
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||||
// last two bytes of the peer address (works for both IPv4 and IPv6).
|
|
||||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||||
if peerAddress == nil {
|
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||||
return nil, fmt.Errorf("nil peer address")
|
if len(octets) != 4 {
|
||||||
}
|
|
||||||
if peerAddress.Port < 0 || peerAddress.Port > 65535 {
|
|
||||||
return nil, fmt.Errorf("invalid UDP port: %d", peerAddress.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(peerAddress.IP)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid IP format")
|
return nil, fmt.Errorf("invalid IP format")
|
||||||
}
|
}
|
||||||
addr = addr.Unmap()
|
|
||||||
|
|
||||||
raw := addr.As16()
|
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||||
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse new IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||||
return &netipAddr, nil
|
return &netipAddr, nil
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -18,7 +19,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||||
@@ -105,10 +105,6 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||||
ipsetByRuleSelectors := make(map[string]string)
|
ipsetByRuleSelectors := make(map[string]string)
|
||||||
|
|
||||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
|
||||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
|
||||||
// the missing deny. Currently we accumulate errors and continue.
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
@@ -121,8 +117,9 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
}
|
}
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||||
continue
|
d.rollBack(newRulePairs)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if len(rulePair) > 0 {
|
if len(rulePair) > 0 {
|
||||||
d.peerRulesPairs[pairID] = rulePair
|
d.peerRulesPairs[pairID] = rulePair
|
||||||
@@ -130,10 +127,6 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if merr != nil {
|
|
||||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
|
||||||
}
|
|
||||||
|
|
||||||
for pairID, rules := range d.peerRulesPairs {
|
for pairID, rules := range d.peerRulesPairs {
|
||||||
if _, ok := newRulePairs[pairID]; !ok {
|
if _, ok := newRulePairs[pairID]; !ok {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -223,9 +216,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
r *mgmProto.FirewallRule,
|
r *mgmProto.FirewallRule,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) (id.RuleID, []firewall.Rule, error) {
|
) (id.RuleID, []firewall.Rule, error) {
|
||||||
ip, err := extractRuleIP(r)
|
ip := net.ParseIP(r.PeerIP)
|
||||||
if err != nil {
|
if ip == nil {
|
||||||
return "", nil, err
|
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||||
@@ -296,13 +289,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
func (d *DefaultManager) addInRules(
|
||||||
id []byte,
|
id []byte,
|
||||||
ip netip.Addr,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -312,7 +305,7 @@ func (d *DefaultManager) addInRules(
|
|||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
id []byte,
|
id []byte,
|
||||||
ip netip.Addr,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
@@ -322,7 +315,7 @@ func (d *DefaultManager) addOutRules(
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -330,9 +323,9 @@ func (d *DefaultManager) addOutRules(
|
|||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||||
func (d *DefaultManager) getPeerRuleID(
|
func (d *DefaultManager) getPeerRuleID(
|
||||||
ip netip.Addr,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
direction int,
|
direction int,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
@@ -351,25 +344,15 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
|||||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
log.Debugf("rollback ACL to previous state")
|
||||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
for _, rules := range newRulePairs {
|
||||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
for _, rule := range rules {
|
||||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||||
if len(r.SourcePrefixes) > 0 {
|
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
||||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
}
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
|
||||||
}
|
}
|
||||||
return addr.Unmap(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
|
||||||
addr, err := netip.ParseAddr(r.PeerIP)
|
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
|
||||||
}
|
|
||||||
return addr.Unmap(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||||
|
|||||||
@@ -321,7 +321,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
|||||||
a.config.DisableFirewall,
|
a.config.DisableFirewall,
|
||||||
a.config.BlockLANAccess,
|
a.config.BlockLANAccess,
|
||||||
a.config.BlockInbound,
|
a.config.BlockInbound,
|
||||||
a.config.DisableIPv6,
|
|
||||||
a.config.LazyConnectionEnabled,
|
a.config.LazyConnectionEnabled,
|
||||||
a.config.EnableSSHRoot,
|
a.config.EnableSSHRoot,
|
||||||
a.config.EnableSSHSFTP,
|
a.config.EnableSSHSFTP,
|
||||||
|
|||||||
@@ -14,13 +14,10 @@ import (
|
|||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
@@ -539,20 +536,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
if config.NetworkMonitor != nil {
|
if config.NetworkMonitor != nil {
|
||||||
nm = *config.NetworkMonitor
|
nm = *config.NetworkMonitor
|
||||||
}
|
}
|
||||||
wgAddr, err := wgaddr.ParseWGAddress(peerConfig.Address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse overlay address %q: %w", peerConfig.Address, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !config.DisableIPv6 {
|
|
||||||
if err := wgAddr.SetIPv6FromCompact(peerConfig.GetAddressV6()); err != nil {
|
|
||||||
log.Warn(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
engineConf := &EngineConfig{
|
engineConf := &EngineConfig{
|
||||||
WgIfaceName: config.WgIface,
|
WgIfaceName: config.WgIface,
|
||||||
WgAddr: wgAddr,
|
WgAddr: peerConfig.Address,
|
||||||
IFaceBlackList: config.IFaceBlackList,
|
IFaceBlackList: config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
@@ -577,7 +563,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
DisableFirewall: config.DisableFirewall,
|
DisableFirewall: config.DisableFirewall,
|
||||||
BlockLANAccess: config.BlockLANAccess,
|
BlockLANAccess: config.BlockLANAccess,
|
||||||
BlockInbound: config.BlockInbound,
|
BlockInbound: config.BlockInbound,
|
||||||
DisableIPv6: config.DisableIPv6,
|
|
||||||
|
|
||||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
|
|
||||||
@@ -652,7 +637,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.DisableFirewall,
|
config.DisableFirewall,
|
||||||
config.BlockLANAccess,
|
config.BlockLANAccess,
|
||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.DisableIPv6,
|
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
config.EnableSSHRoot,
|
config.EnableSSHRoot,
|
||||||
config.EnableSSHSFTP,
|
config.EnableSSHSFTP,
|
||||||
|
|||||||
@@ -40,10 +40,6 @@ func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
|||||||
// network stack, not by OS-level interface configuration.
|
// network stack, not by OS-level interface configuration.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (noopNetworkChangeListener) SetInterfaceIPv6(string) {
|
|
||||||
// No-op: same as SetInterfaceIP, IPv6 overlay is managed by userspace stack.
|
|
||||||
}
|
|
||||||
|
|
||||||
// noopDnsReadyListener is a stub for embed.Client on Android.
|
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||||
// DNS readiness notifications are not needed in netstack/embed mode
|
// DNS readiness notifications are not needed in netstack/embed mode
|
||||||
// since system DNS is disabled and DNS resolution happens externally.
|
// since system DNS is disabled and DNS resolution happens externally.
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
@@ -31,7 +30,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -585,9 +583,6 @@ func isSensitiveEnvVar(key string) bool {
|
|||||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||||
|
|
||||||
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
|
|
||||||
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
|
|
||||||
}
|
|
||||||
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
|
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
|
||||||
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
|
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
|
||||||
if g.internalConfig.NetworkMonitor != nil {
|
if g.internalConfig.NetworkMonitor != nil {
|
||||||
@@ -612,12 +607,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||||
}
|
}
|
||||||
if g.internalConfig.DisableSSHAuth != nil {
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
|
|
||||||
}
|
|
||||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
|
||||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
|
||||||
}
|
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
@@ -625,7 +614,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||||
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
|
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableIPv6: %v\n", g.internalConfig.DisableIPv6))
|
|
||||||
|
|
||||||
if g.internalConfig.DisableNotifications != nil {
|
if g.internalConfig.DisableNotifications != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
|
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
|
||||||
@@ -645,7 +633,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
}
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||||
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addProf() (err error) {
|
func (g *BundleGenerator) addProf() (err error) {
|
||||||
@@ -1296,21 +1283,6 @@ func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anon
|
|||||||
config.Address = anonymizer.AnonymizeIP(addr).String()
|
config.Address = anonymizer.AnonymizeIP(addr).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.GetAddressV6()) > 0 {
|
|
||||||
v6Prefix, err := netiputil.DecodePrefix(config.GetAddressV6())
|
|
||||||
if err != nil {
|
|
||||||
config.AddressV6 = nil
|
|
||||||
} else {
|
|
||||||
anonV6 := anonymizer.AnonymizeIP(v6Prefix.Addr())
|
|
||||||
b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonV6, v6Prefix.Bits()))
|
|
||||||
if err != nil {
|
|
||||||
config.AddressV6 = nil
|
|
||||||
} else {
|
|
||||||
config.AddressV6 = b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
anonymizeSSHConfig(config.SshConfig)
|
anonymizeSSHConfig(config.SshConfig)
|
||||||
|
|
||||||
config.Dns = anonymizer.AnonymizeString(config.Dns)
|
config.Dns = anonymizer.AnonymizeString(config.Dns)
|
||||||
@@ -1413,20 +1385,8 @@ func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.An
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:staticcheck // PeerIP used for backward compatibility
|
|
||||||
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
|
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
|
||||||
rule.PeerIP = anonymizer.AnonymizeIP(addr).String() //nolint:staticcheck
|
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
|
||||||
}
|
|
||||||
|
|
||||||
for i, raw := range rule.GetSourcePrefixes() {
|
|
||||||
p, err := netiputil.DecodePrefix(raw)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
anonAddr := anonymizer.AnonymizeIP(p.Addr())
|
|
||||||
if b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonAddr, p.Bits())); err == nil {
|
|
||||||
rule.SourcePrefixes[i] = b
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,33 +5,19 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/configs"
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
|
||||||
t.Helper()
|
|
||||||
b, err := netiputil.EncodePrefix(p)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAnonymizeStateFile(t *testing.T) {
|
func TestAnonymizeStateFile(t *testing.T) {
|
||||||
testState := map[string]json.RawMessage{
|
testState := map[string]json.RawMessage{
|
||||||
"null_state": json.RawMessage("null"),
|
"null_state": json.RawMessage("null"),
|
||||||
@@ -182,7 +168,7 @@ func TestAnonymizeStateFile(t *testing.T) {
|
|||||||
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
|
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
|
||||||
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
|
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
|
||||||
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
|
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
|
||||||
assert.NotEqual(t, "fd00::1", state["private_ipv6"]) // ULA IPv6 anonymized (global ID is a fingerprint)
|
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
|
||||||
assert.NotEqual(t, "test.example.com", state["domain"])
|
assert.NotEqual(t, "test.example.com", state["domain"])
|
||||||
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
|
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
|
||||||
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
|
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
|
||||||
@@ -286,13 +272,11 @@ func mustMarshal(v any) json.RawMessage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymizeNetworkMap(t *testing.T) {
|
func TestAnonymizeNetworkMap(t *testing.T) {
|
||||||
origV6Prefix := netip.MustParsePrefix("2001:db8:abcd::5/64")
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
PeerConfig: &mgmProto.PeerConfig{
|
PeerConfig: &mgmProto.PeerConfig{
|
||||||
Address: "203.0.113.5",
|
Address: "203.0.113.5",
|
||||||
AddressV6: mustEncodePrefix(t, origV6Prefix),
|
Dns: "1.2.3.4",
|
||||||
Dns: "1.2.3.4",
|
Fqdn: "peer1.corp.example.com",
|
||||||
Fqdn: "peer1.corp.example.com",
|
|
||||||
SshConfig: &mgmProto.SSHConfig{
|
SshConfig: &mgmProto.SSHConfig{
|
||||||
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
|
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
|
||||||
},
|
},
|
||||||
@@ -366,12 +350,6 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
|||||||
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
|
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
|
||||||
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
|
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
|
||||||
|
|
||||||
// Verify AddressV6 is anonymized but preserves prefix length
|
|
||||||
anonV6Prefix, err := netiputil.DecodePrefix(peerCfg.AddressV6)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, origV6Prefix.Bits(), anonV6Prefix.Bits(), "prefix length must be preserved")
|
|
||||||
assert.NotEqual(t, origV6Prefix.Addr(), anonV6Prefix.Addr(), "IPv6 address must be anonymized")
|
|
||||||
|
|
||||||
// Verify SSH key is replaced
|
// Verify SSH key is replaced
|
||||||
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
|
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
|
||||||
|
|
||||||
@@ -493,8 +471,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
|||||||
anonymize: false,
|
anonymize: false,
|
||||||
input: map[string]any{
|
input: map[string]any{
|
||||||
jsonKeyServiceEnv: map[string]any{
|
jsonKeyServiceEnv: map[string]any{
|
||||||
"HOME": "/root",
|
"HOME": "/root",
|
||||||
"PATH": "/usr/bin",
|
"PATH": "/usr/bin",
|
||||||
"NB_LOG_LEVEL": "debug",
|
"NB_LOG_LEVEL": "debug",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -511,9 +489,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
|||||||
anonymize: false,
|
anonymize: false,
|
||||||
input: map[string]any{
|
input: map[string]any{
|
||||||
jsonKeyServiceEnv: map[string]any{
|
jsonKeyServiceEnv: map[string]any{
|
||||||
"NB_SETUP_KEY": "abc123",
|
"NB_SETUP_KEY": "abc123",
|
||||||
"NB_API_TOKEN": "tok_xyz",
|
"NB_API_TOKEN": "tok_xyz",
|
||||||
"NB_LOG_LEVEL": "info",
|
"NB_LOG_LEVEL": "info",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
check: func(t *testing.T, params map[string]any) {
|
check: func(t *testing.T, params map[string]any) {
|
||||||
@@ -677,6 +655,8 @@ func isInCGNATRange(ip net.IP) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymizeFirewallRules(t *testing.T) {
|
func TestAnonymizeFirewallRules(t *testing.T) {
|
||||||
|
// TODO: Add ipv6
|
||||||
|
|
||||||
// Example iptables-save output
|
// Example iptables-save output
|
||||||
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||||
*filter
|
*filter
|
||||||
@@ -712,31 +692,17 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
|
|||||||
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||||
pkts bytes target prot opt in out source destination`
|
pkts bytes target prot opt in out source destination`
|
||||||
|
|
||||||
// Example ip6tables-save output
|
// Example nftables output
|
||||||
ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
|
||||||
*filter
|
|
||||||
:INPUT ACCEPT [0:0]
|
|
||||||
:FORWARD ACCEPT [0:0]
|
|
||||||
:OUTPUT ACCEPT [0:0]
|
|
||||||
-A INPUT -s fd00:1234::1/128 -j ACCEPT
|
|
||||||
-A INPUT -s 2607:f8b0:4005::1/128 -j DROP
|
|
||||||
-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT
|
|
||||||
COMMIT`
|
|
||||||
|
|
||||||
// Example nftables output with IPv6
|
|
||||||
nftablesRules := `table inet filter {
|
nftablesRules := `table inet filter {
|
||||||
chain input {
|
chain input {
|
||||||
type filter hook input priority filter; policy accept;
|
type filter hook input priority filter; policy accept;
|
||||||
ip saddr 192.168.1.1 accept
|
ip saddr 192.168.1.1 accept
|
||||||
ip saddr 44.192.140.1 drop
|
ip saddr 44.192.140.1 drop
|
||||||
ip6 saddr 2607:f8b0:4005::1 drop
|
|
||||||
ip6 saddr fd00:1234::1 accept
|
|
||||||
}
|
}
|
||||||
chain forward {
|
chain forward {
|
||||||
type filter hook forward priority filter; policy accept;
|
type filter hook forward priority filter; policy accept;
|
||||||
ip saddr 10.0.0.0/8 drop
|
ip saddr 10.0.0.0/8 drop
|
||||||
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
||||||
ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept
|
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
@@ -799,159 +765,4 @@ COMMIT`
|
|||||||
assert.Contains(t, anonNftables, "table inet filter {")
|
assert.Contains(t, anonNftables, "table inet filter {")
|
||||||
assert.Contains(t, anonNftables, "chain input {")
|
assert.Contains(t, anonNftables, "chain input {")
|
||||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||||
|
|
||||||
// IPv6 public addresses in nftables should be anonymized
|
|
||||||
assert.NotContains(t, anonNftables, "2607:f8b0:4005::1")
|
|
||||||
assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e")
|
|
||||||
assert.NotContains(t, anonNftables, "2001:db8::")
|
|
||||||
assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range
|
|
||||||
|
|
||||||
// ULA addresses in nftables should be anonymized (global ID is a fingerprint)
|
|
||||||
assert.NotContains(t, anonNftables, "fd00:1234::1")
|
|
||||||
|
|
||||||
// IPv6 nftables structure preserved
|
|
||||||
assert.Contains(t, anonNftables, "ip6 saddr")
|
|
||||||
assert.Contains(t, anonNftables, "ip6 daddr")
|
|
||||||
|
|
||||||
// Test ip6tables-save anonymization
|
|
||||||
anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave)
|
|
||||||
|
|
||||||
// ULA IPv6 should be anonymized (global ID is a fingerprint)
|
|
||||||
assert.NotContains(t, anonIp6tablesSave, "fd00:1234::1/128")
|
|
||||||
|
|
||||||
// Public IPv6 addresses should be anonymized
|
|
||||||
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1")
|
|
||||||
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e")
|
|
||||||
assert.NotContains(t, anonIp6tablesSave, "2001:db8::")
|
|
||||||
assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range
|
|
||||||
|
|
||||||
// Structure should be preserved
|
|
||||||
assert.Contains(t, anonIp6tablesSave, "*filter")
|
|
||||||
assert.Contains(t, anonIp6tablesSave, "COMMIT")
|
|
||||||
assert.Contains(t, anonIp6tablesSave, "-j DROP")
|
|
||||||
assert.Contains(t, anonIp6tablesSave, "-j ACCEPT")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
|
|
||||||
// profilemanager.Config is either rendered in the debug bundle or explicitly
|
|
||||||
// excluded. When a new field is added to Config, this test fails until the
|
|
||||||
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
|
|
||||||
// the excluded set with a justification.
|
|
||||||
func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
|
||||||
excluded := map[string]string{
|
|
||||||
"PrivateKey": "sensitive: WireGuard private key",
|
|
||||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
|
||||||
"SSHKey": "sensitive: SSH private key",
|
|
||||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
|
||||||
}
|
|
||||||
|
|
||||||
mURL, _ := url.Parse("https://api.example.com:443")
|
|
||||||
aURL, _ := url.Parse("https://admin.example.com:443")
|
|
||||||
bTrue := true
|
|
||||||
iVal := 42
|
|
||||||
cfg := &profilemanager.Config{
|
|
||||||
PrivateKey: "priv",
|
|
||||||
PreSharedKey: "psk",
|
|
||||||
ManagementURL: mURL,
|
|
||||||
AdminURL: aURL,
|
|
||||||
WgIface: "wt0",
|
|
||||||
WgPort: 51820,
|
|
||||||
NetworkMonitor: &bTrue,
|
|
||||||
IFaceBlackList: []string{"eth0"},
|
|
||||||
DisableIPv6Discovery: true,
|
|
||||||
RosenpassEnabled: true,
|
|
||||||
RosenpassPermissive: true,
|
|
||||||
ServerSSHAllowed: &bTrue,
|
|
||||||
EnableSSHRoot: &bTrue,
|
|
||||||
EnableSSHSFTP: &bTrue,
|
|
||||||
EnableSSHLocalPortForwarding: &bTrue,
|
|
||||||
EnableSSHRemotePortForwarding: &bTrue,
|
|
||||||
DisableSSHAuth: &bTrue,
|
|
||||||
SSHJWTCacheTTL: &iVal,
|
|
||||||
DisableClientRoutes: true,
|
|
||||||
DisableServerRoutes: true,
|
|
||||||
DisableDNS: true,
|
|
||||||
DisableFirewall: true,
|
|
||||||
BlockLANAccess: true,
|
|
||||||
BlockInbound: true,
|
|
||||||
DisableNotifications: &bTrue,
|
|
||||||
DNSLabels: domain.List{},
|
|
||||||
SSHKey: "sshkey",
|
|
||||||
NATExternalIPs: []string{"1.2.3.4"},
|
|
||||||
CustomDNSAddress: "1.1.1.1:53",
|
|
||||||
DisableAutoConnect: true,
|
|
||||||
DNSRouteInterval: 5 * time.Second,
|
|
||||||
ClientCertPath: "/tmp/cert",
|
|
||||||
ClientCertKeyPath: "/tmp/key",
|
|
||||||
LazyConnectionEnabled: true,
|
|
||||||
MTU: 1280,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, anonymize := range []bool{false, true} {
|
|
||||||
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
|
|
||||||
g := &BundleGenerator{
|
|
||||||
anonymizer: newAnonymizerForTest(),
|
|
||||||
internalConfig: cfg,
|
|
||||||
anonymize: anonymize,
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
g.addCommonConfigFields(&sb)
|
|
||||||
rendered := sb.String() + renderAddConfigSpecific(g)
|
|
||||||
|
|
||||||
val := reflect.ValueOf(cfg).Elem()
|
|
||||||
typ := val.Type()
|
|
||||||
var missing []string
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
|
||||||
name := typ.Field(i).Name
|
|
||||||
if _, ok := excluded[name]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !strings.Contains(rendered, name+":") {
|
|
||||||
missing = append(missing, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(missing) > 0 {
|
|
||||||
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
|
|
||||||
"Either render the field in addCommonConfigFields/addConfig, "+
|
|
||||||
"or add it to the excluded map with a justification.", missing)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
|
|
||||||
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
|
|
||||||
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
|
|
||||||
// production shape without needing to write an actual zip.
|
|
||||||
func renderAddConfigSpecific(g *BundleGenerator) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
if g.anonymize {
|
|
||||||
if g.internalConfig.ManagementURL != nil {
|
|
||||||
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
|
|
||||||
}
|
|
||||||
if g.internalConfig.AdminURL != nil {
|
|
||||||
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
|
|
||||||
}
|
|
||||||
sb.WriteString("NATExternalIPs: x\n")
|
|
||||||
if g.internalConfig.CustomDNSAddress != "" {
|
|
||||||
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if g.internalConfig.ManagementURL != nil {
|
|
||||||
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
|
|
||||||
}
|
|
||||||
if g.internalConfig.AdminURL != nil {
|
|
||||||
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
|
|
||||||
}
|
|
||||||
sb.WriteString("NATExternalIPs: x\n")
|
|
||||||
if g.internalConfig.CustomDNSAddress != "" {
|
|
||||||
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAnonymizerForTest() *anonymize.Anonymizer {
|
|
||||||
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,83 +12,52 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createPTRRecord(record nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||||
ip, err := netip.ParseAddr(record.RData)
|
ip, err := netip.ParseAddr(aRecord.RData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse IP address %s: %v", record.RData, err)
|
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
ip = ip.Unmap()
|
|
||||||
if !prefix.Contains(ip) {
|
if !prefix.Contains(ip) {
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
var rdnsName string
|
ipOctets := strings.Split(ip.String(), ".")
|
||||||
if ip.Is4() {
|
slices.Reverse(ipOctets)
|
||||||
octets := strings.Split(ip.String(), ".")
|
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||||
slices.Reverse(octets)
|
|
||||||
rdnsName = dns.Fqdn(strings.Join(octets, ".") + ".in-addr.arpa")
|
|
||||||
} else {
|
|
||||||
// Expand to full 32 nibbles in reverse order (LSB first) per RFC 3596.
|
|
||||||
raw := ip.As16()
|
|
||||||
nibbles := make([]string, 32)
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
nibbles[31-i*2] = fmt.Sprintf("%x", raw[i]>>4)
|
|
||||||
nibbles[31-i*2-1] = fmt.Sprintf("%x", raw[i]&0x0f)
|
|
||||||
}
|
|
||||||
rdnsName = dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nbdns.SimpleRecord{
|
return nbdns.SimpleRecord{
|
||||||
Name: rdnsName,
|
Name: rdnsName,
|
||||||
Type: int(dns.TypePTR),
|
Type: int(dns.TypePTR),
|
||||||
Class: record.Class,
|
Class: aRecord.Class,
|
||||||
TTL: record.TTL,
|
TTL: aRecord.TTL,
|
||||||
RData: dns.Fqdn(record.Name),
|
RData: dns.Fqdn(aRecord.Name),
|
||||||
}, true
|
}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateReverseZoneName creates the reverse DNS zone name for a given network.
|
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||||
// For IPv4 it produces an in-addr.arpa name, for IPv6 an ip6.arpa name.
|
|
||||||
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||||
networkIP := network.Masked().Addr().Unmap()
|
networkIP := network.Masked().Addr()
|
||||||
bits := network.Bits()
|
|
||||||
|
|
||||||
if networkIP.Is4() {
|
if !networkIP.Is4() {
|
||||||
// Round up to nearest byte.
|
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||||
octetsToUse := (bits + 7) / 8
|
|
||||||
|
|
||||||
octets := strings.Split(networkIP.String(), ".")
|
|
||||||
if octetsToUse > len(octets) {
|
|
||||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", bits)
|
|
||||||
}
|
|
||||||
|
|
||||||
reverseOctets := make([]string, octetsToUse)
|
|
||||||
for i := 0; i < octetsToUse; i++ {
|
|
||||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPv6: round up to nearest nibble (4-bit boundary).
|
// round up to nearest byte
|
||||||
nibblesToUse := (bits + 3) / 4
|
octetsToUse := (network.Bits() + 7) / 8
|
||||||
|
|
||||||
raw := networkIP.As16()
|
octets := strings.Split(networkIP.String(), ".")
|
||||||
allNibbles := make([]string, 32)
|
if octetsToUse > len(octets) {
|
||||||
for i := 0; i < 16; i++ {
|
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||||
allNibbles[i*2] = fmt.Sprintf("%x", raw[i]>>4)
|
|
||||||
allNibbles[i*2+1] = fmt.Sprintf("%x", raw[i]&0x0f)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take the first nibblesToUse nibbles (network portion), reverse them.
|
reverseOctets := make([]string, octetsToUse)
|
||||||
used := make([]string, nibblesToUse)
|
for i := 0; i < octetsToUse; i++ {
|
||||||
for i := 0; i < nibblesToUse; i++ {
|
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||||
used[nibblesToUse-1-i] = allNibbles[i]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return dns.Fqdn(strings.Join(used, ".") + ".ip6.arpa"), nil
|
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// zoneExists checks if a zone with the given name already exists in the configuration
|
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||||
@@ -102,7 +71,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// collectPTRRecords gathers all PTR records for the given network from A and AAAA records.
|
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||||
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
@@ -111,7 +80,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
if record.Type != int(dns.TypeA) && record.Type != int(dns.TypeAAAA) {
|
if record.Type != int(dns.TypeA) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -298,7 +298,6 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||||
ip = ip.Unmap()
|
ip = ip.Unmap()
|
||||||
serverAddresses = append(serverAddresses, ip)
|
serverAddresses = append(serverAddresses, ip)
|
||||||
// Prefer the first IPv4 server as ServerIP since our DNS listener is IPv4.
|
|
||||||
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||||
dnsSettings.ServerIP = ip
|
dnsSettings.ServerIP = ip
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
@@ -66,9 +67,9 @@ func (d *Resolver) Stop() {
|
|||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
clear(d.records)
|
maps.Clear(d.records)
|
||||||
clear(d.domains)
|
maps.Clear(d.domains)
|
||||||
clear(d.zones)
|
maps.Clear(d.zones)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID
|
||||||
@@ -443,9 +444,9 @@ func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
|||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
clear(d.records)
|
maps.Clear(d.records)
|
||||||
clear(d.domains)
|
maps.Clear(d.domains)
|
||||||
clear(d.zones)
|
maps.Clear(d.zones)
|
||||||
|
|
||||||
for _, zone := range customZones {
|
for _, zone := range customZones {
|
||||||
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
|
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
|
||||||
|
|||||||
@@ -110,25 +110,8 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
|||||||
|
|
||||||
connSettings.cleanDeprecatedSettings()
|
connSettings.cleanDeprecatedSettings()
|
||||||
|
|
||||||
ipKey := networkManagerDbusIPv4Key
|
convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice())
|
||||||
staleKey := networkManagerDbusIPv6Key
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
||||||
if config.ServerIP.Is6() {
|
|
||||||
ipKey = networkManagerDbusIPv6Key
|
|
||||||
staleKey = networkManagerDbusIPv4Key
|
|
||||||
raw := config.ServerIP.As16()
|
|
||||||
connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([][]byte{raw[:]})
|
|
||||||
} else {
|
|
||||||
convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice())
|
|
||||||
connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear stale DNS settings from the opposite address family to avoid
|
|
||||||
// leftover entries if the server IP family changed.
|
|
||||||
if staleSettings, ok := connSettings[staleKey]; ok {
|
|
||||||
delete(staleSettings, networkManagerDbusDNSKey)
|
|
||||||
delete(staleSettings, networkManagerDbusDNSPriorityKey)
|
|
||||||
delete(staleSettings, networkManagerDbusDNSSearchKey)
|
|
||||||
}
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
@@ -163,8 +146,8 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
|||||||
n.routingAll = false
|
n.routingAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
connSettings[ipKey][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||||
connSettings[ipKey][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||||
|
|
||||||
state := &ShutdownState{
|
state := &ShutdownState{
|
||||||
ManagerType: networkManager,
|
ManagerType: networkManager,
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ func (s *DefaultServer) Stop() {
|
|||||||
log.Errorf("failed to disable DNS: %v", err)
|
log.Errorf("failed to disable DNS: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clear(s.extraDomains)
|
maps.Clear(s.extraDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
func (s *DefaultServer) disableDNS() (retErr error) {
|
||||||
|
|||||||
@@ -347,7 +347,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
opts := iface.WGIFaceOpts{
|
||||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
|
||||||
WGPort: 33100,
|
WGPort: 33100,
|
||||||
WGPrivKey: privKey.String(),
|
WGPrivKey: privKey.String(),
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
@@ -448,7 +448,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
opts := iface.WGIFaceOpts{
|
opts := iface.WGIFaceOpts{
|
||||||
IFaceName: "utun2301",
|
IFaceName: "utun2301",
|
||||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
Address: "100.66.100.1/32",
|
||||||
WGPort: 33100,
|
WGPort: 33100,
|
||||||
WGPrivKey: privKey.String(),
|
WGPrivKey: privKey.String(),
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
@@ -929,7 +929,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
opts := iface.WGIFaceOpts{
|
||||||
IFaceName: "utun2301",
|
IFaceName: "utun2301",
|
||||||
Address: wgaddr.MustParseWGAddress("100.66.100.2/24"),
|
Address: "100.66.100.2/24",
|
||||||
WGPort: 33100,
|
WGPort: 33100,
|
||||||
WGPrivKey: privKey.String(),
|
WGPrivKey: privKey.String(),
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ const (
|
|||||||
// This is used when the DNS server cannot bind port 53 directly
|
// This is used when the DNS server cannot bind port 53 directly
|
||||||
// and needs firewall rules to redirect traffic.
|
// and needs firewall rules to redirect traffic.
|
||||||
type Firewall interface {
|
type Firewall interface {
|
||||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error
|
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error
|
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type service interface {
|
type service interface {
|
||||||
|
|||||||
@@ -188,10 +188,11 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
|||||||
return s.listenIP
|
return s.listenIP
|
||||||
}
|
}
|
||||||
|
|
||||||
// evalListenAddress figures out the listen address for the DNS server.
|
|
||||||
// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4.
|
// evalListenAddress figure out the listen address for the DNS server
|
||||||
// First checks port 53 on WG interface or lo, then tries eBPF on a random port,
|
// first check the 53 port availability on WG interface or lo, if not success
|
||||||
// then falls back to port 5053.
|
// pick a random port on WG interface for eBPF, if not success
|
||||||
|
// check the 5053 port availability on WG interface or lo without eBPF usage,
|
||||||
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
||||||
if s.customAddr != nil {
|
if s.customAddr != nil {
|
||||||
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
||||||
@@ -277,7 +278,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
||||||
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port))
|
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
||||||
return nil, 0, false
|
return nil, 0, false
|
||||||
|
|||||||
@@ -90,12 +90,8 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
family := int32(unix.AF_INET)
|
|
||||||
if config.ServerIP.Is6() {
|
|
||||||
family = unix.AF_INET6
|
|
||||||
}
|
|
||||||
defaultLinkInput := systemdDbusDNSInput{
|
defaultLinkInput := systemdDbusDNSInput{
|
||||||
Family: family,
|
Family: unix.AF_INET,
|
||||||
Address: config.ServerIP.AsSlice(),
|
Address: config.ServerIP.AsSlice(),
|
||||||
}
|
}
|
||||||
if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
|
if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -30,33 +29,6 @@ import (
|
|||||||
|
|
||||||
var currentMTU uint16 = iface.DefaultMTU
|
var currentMTU uint16 = iface.DefaultMTU
|
||||||
|
|
||||||
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
|
|
||||||
// from one upstream means another upstream would return the same answer:
|
|
||||||
// DNSSEC validation outcomes and policy-based blocks. Transient errors
|
|
||||||
// (network, cached, not ready) are not included.
|
|
||||||
var nonRetryableEDECodes = map[uint16]struct{}{
|
|
||||||
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
|
|
||||||
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
|
|
||||||
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
|
|
||||||
dns.ExtendedErrorCodeDNSBogus: {},
|
|
||||||
dns.ExtendedErrorCodeSignatureExpired: {},
|
|
||||||
dns.ExtendedErrorCodeSignatureNotYetValid: {},
|
|
||||||
dns.ExtendedErrorCodeDNSKEYMissing: {},
|
|
||||||
dns.ExtendedErrorCodeRRSIGsMissing: {},
|
|
||||||
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
|
|
||||||
dns.ExtendedErrorCodeNSECMissing: {},
|
|
||||||
dns.ExtendedErrorCodeBlocked: {},
|
|
||||||
dns.ExtendedErrorCodeCensored: {},
|
|
||||||
dns.ExtendedErrorCodeFiltered: {},
|
|
||||||
dns.ExtendedErrorCodeProhibited: {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
|
|
||||||
type privateClientIface interface {
|
|
||||||
Name() string
|
|
||||||
Address() wgaddr.Address
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetCurrentMTU(mtu uint16) {
|
func SetCurrentMTU(mtu uint16) {
|
||||||
currentMTU = mtu
|
currentMTU = mtu
|
||||||
}
|
}
|
||||||
@@ -271,18 +243,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
|||||||
var t time.Duration
|
var t time.Duration
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
|
||||||
// (RFC 8914) in failure responses; we use those to short-circuit
|
|
||||||
// failover for definitive answers like DNSSEC validation failures.
|
|
||||||
// Operate on a copy so the inbound request is unchanged: a client that
|
|
||||||
// did not advertise EDNS0 must not see an OPT in the response.
|
|
||||||
hadEdns := r.IsEdns0() != nil
|
|
||||||
reqUp := r
|
|
||||||
if !hadEdns {
|
|
||||||
reqUp = r.Copy()
|
|
||||||
reqUp.SetEdns0(upstreamUDPSize(), false)
|
|
||||||
}
|
|
||||||
|
|
||||||
var startTime time.Time
|
var startTime time.Time
|
||||||
var upstreamProto *upstreamProtocolResult
|
var upstreamProto *upstreamProtocolResult
|
||||||
func() {
|
func() {
|
||||||
@@ -290,7 +250,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||||
startTime = time.Now()
|
startTime = time.Now()
|
||||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
|
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -302,49 +262,13 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||||
if code, ok := nonRetryableEDE(rm); ok {
|
|
||||||
resutil.SetMeta(w, "ede", edeName(code))
|
|
||||||
if !hadEdns {
|
|
||||||
stripOPT(rm)
|
|
||||||
}
|
|
||||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hadEdns {
|
|
||||||
stripOPT(rm)
|
|
||||||
}
|
|
||||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
|
||||||
// derived from the tunnel MTU and bounded against underflow.
|
|
||||||
func upstreamUDPSize() uint16 {
|
|
||||||
if currentMTU > ipUDPHeaderSize {
|
|
||||||
return currentMTU - ipUDPHeaderSize
|
|
||||||
}
|
|
||||||
return dns.MinMsgSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
|
||||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
|
||||||
func stripOPT(rm *dns.Msg) {
|
|
||||||
if len(rm.Extra) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out := rm.Extra[:0]
|
|
||||||
for _, rr := range rm.Extra {
|
|
||||||
if _, ok := rr.(*dns.OPT); ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out = append(out, rr)
|
|
||||||
}
|
|
||||||
rm.Extra = out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||||
@@ -406,34 +330,6 @@ func formatFailures(failures []upstreamFailure) string {
|
|||||||
return strings.Join(parts, ", ")
|
return strings.Join(parts, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// nonRetryableEDE returns the first non-retryable EDE code carried in the
|
|
||||||
// response, if any.
|
|
||||||
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
|
|
||||||
opt := rm.IsEdns0()
|
|
||||||
if opt == nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
for _, o := range opt.Option {
|
|
||||||
ede, ok := o.(*dns.EDNS0_EDE)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
|
|
||||||
return ede.InfoCode, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// edeName returns a human-readable name for an EDE code, falling back to
|
|
||||||
// the numeric code when unknown.
|
|
||||||
func edeName(code uint16) string {
|
|
||||||
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("EDE %d", code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProbeAvailability tests all upstream servers simultaneously and
|
// ProbeAvailability tests all upstream servers simultaneously and
|
||||||
// disables the resolver if none work
|
// disables the resolver if none work
|
||||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
return &dns.Client{
|
return &dns.Client{
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
|
|||||||
return ExchangeWithFallback(ctx, client, r, upstream)
|
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
return &dns.Client{
|
return &dns.Client{
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ import (
|
|||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
wgIface WGIface
|
lIP netip.Addr
|
||||||
|
lNet netip.Prefix
|
||||||
|
interfaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
@@ -33,7 +35,9 @@ func newUpstreamResolver(
|
|||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
wgIface: wgIface,
|
lIP: wgIface.Address().IP,
|
||||||
|
lNet: wgIface.Address().Network,
|
||||||
|
interfaceName: wgIface.Name(),
|
||||||
}
|
}
|
||||||
ios.upstreamClient = ios
|
ios.upstreamClient = ios
|
||||||
|
|
||||||
@@ -61,13 +65,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
} else {
|
} else {
|
||||||
upstreamIP = upstreamIP.Unmap()
|
upstreamIP = upstreamIP.Unmap()
|
||||||
}
|
}
|
||||||
addr := u.wgIface.Address()
|
needsPrivate := u.lNet.Contains(upstreamIP) ||
|
||||||
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
|
||||||
addr.IPv6Net.Contains(upstreamIP) ||
|
|
||||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||||
if needsPrivate {
|
if needsPrivate {
|
||||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||||
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("create private client: %s", err)
|
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||||
}
|
}
|
||||||
@@ -77,33 +79,25 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
return ExchangeWithFallback(nil, client, r, upstream)
|
return ExchangeWithFallback(nil, client, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4.
|
// This method is needed for iOS
|
||||||
func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
index, err := getInterfaceIndex(iface.Name())
|
index, err := getInterfaceIndex(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("unable to get interface index for %s: %s", iface.Name(), err)
|
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := iface.Address()
|
|
||||||
bindIP := addr.IP
|
|
||||||
if upstreamIP.Is6() && addr.HasIPv6() {
|
|
||||||
bindIP = addr.IPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF
|
|
||||||
if bindIP.Is6() {
|
|
||||||
proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)),
|
LocalAddr: &net.UDPAddr{
|
||||||
Timeout: dialTimeout,
|
IP: ip.AsSlice(),
|
||||||
|
Port: 0, // Let the OS pick a free port
|
||||||
|
},
|
||||||
|
Timeout: dialTimeout,
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
var operr error
|
var operr error
|
||||||
fn := func(s uintptr) {
|
fn := func(s uintptr) {
|
||||||
operr = unix.SetsockoptInt(int(s), proto, opt, index)
|
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Control(fn); err != nil {
|
if err := c.Control(fn); err != nil {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user