diff --git a/client/internal/engine.go b/client/internal/engine.go index 10d74d931..2a98ee81e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -8,6 +8,7 @@ import ( "net/netip" "reflect" "runtime" + "strconv" "strings" "sync" "time" @@ -16,6 +17,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/proxy" @@ -112,7 +114,9 @@ type Engine struct { statusRecorder *peer.Status - routeManager routemanager.Manager + routeManager routemanager.Manager + firewallManager firewall.Manager + firewallRules map[string]firewall.Rule dnsServer dns.Server } @@ -219,6 +223,12 @@ func (e *Engine) Start() error { e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) + e.firewallManager, err = buildFirewallManager() + if err != nil { + log.Error("failed to create firewall manager, ACL policy will not work: %s", err.Error()) + } + e.firewallRules = make(map[string]firewall.Rule) + if e.dnsServer == nil { // todo fix custom address dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) @@ -622,6 +632,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } + if err := e.applyFirewallRules(networkMap.FirewallRules); err != nil { + log.Errorf("failed apply firewall rules, err: %v", err) + } e.networkSerial = serial return nil } @@ -1030,6 +1043,96 @@ func (e *Engine) close() { } +// applyFirewallRules to the local firewall manager processed by ACL policy. +func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error { + if e.firewallManager == nil { + log.Debug("firewall manager is not supported, skipping firewall rules") + return nil + } + + newRules := make([]string, 0) + for _, r := range rules { + rule := e.protoRuleToFirewallRule(r) + if rule == nil { + continue + } + newRules = append(newRules, rule.GetRuleID()) + } + + for _, ruleID := range newRules { + if rule, ok := e.firewallRules[ruleID]; ok { + if err := e.firewallManager.DeleteRule(rule); err != nil { + log.Debug("failed to delete firewall rule: %v", err) + continue + } + delete(e.firewallRules, ruleID) + } + } + + return nil +} + +func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule { + ip := net.ParseIP(r.PeerIP) + if ip == nil { + log.Debug("invalid IP address, skipping firewall rule") + return nil + } + + var port firewall.Port + if r.Port != "" { + split := strings.Split(r.Port, "/") + value, err := strconv.Atoi(split[0]) + if err != nil { + log.Debug("invalid port, skipping firewall rule") + return nil + } + port.Values = []int{value} + // get protocol from the port suffix if it exists + if len(split) > 1 { + switch split[1] { + case "tcp": + port.Proto = firewall.PortProtocolTCP + case "udp": + port.Proto = firewall.PortProtocolUDP + default: + log.Debug("invalid protocol, skipping firewall rule") + return nil + } + } + } + + var direction firewall.Direction + switch r.Direction { + case "src": + direction = firewall.DirectionSrc + case "dst": + direction = firewall.DirectionDst + default: + log.Debug("invalid direction, skipping firewall rule") + return nil + } + + var action firewall.Action + switch r.Action { + case "accept": + action = firewall.ActionAccept + case "drop": + action = firewall.ActionDrop + default: + log.Debug("invalid action, skipping firewall rule") + return nil + } + + rule, err := e.firewallManager.AddFiltering(ip, &port, direction, action, "") + if err != nil { + log.Debug("failed to add firewall rule: %v", err) + return nil + } + e.firewallRules[rule.GetRuleID()] = rule + return rule +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/client/internal/firewall.go b/client/internal/firewall.go new file mode 100644 index 000000000..70c4d8d10 --- /dev/null +++ b/client/internal/firewall.go @@ -0,0 +1,14 @@ +//go:build !linux + +package internal + +import ( + "fmt" + "runtime" + + "github.com/netbirdio/netbird/client/firewall" +) + +func buildFirewallManager() (fw firewall.Manager, err error) { + return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) +} diff --git a/client/internal/firewall_linux.go b/client/internal/firewall_linux.go new file mode 100644 index 000000000..89fd95454 --- /dev/null +++ b/client/internal/firewall_linux.go @@ -0,0 +1,17 @@ +package internal + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/firewall/iptables" +) + +func buildFirewallManager() (fw firewall.Manager, err error) { + fw, err = iptables.Create() + if err != nil { + // TODO: handle init nftables manager when it will be implemented + return nil, fmt.Errorf("create iptables manager: %w", err) + } + return fw, nil +} diff --git a/management/server/account.go b/management/server/account.go index 5b9c9402d..809f3953d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -272,7 +272,7 @@ func (a *Account) GetGroup(groupID string) *Group { // GetPeerNetworkMap returns a group by ID if exists, nil otherwise func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { - aclPeers, _ := a.getPeersByPolicy(peerID) + aclPeers, firewallRules := a.getPeersByPolicy(peerID) // exclude expired peers var peersToConnect []*Peer var expiredPeers []*Peer @@ -303,11 +303,12 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { } return &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, } } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index fa0e49ed3..ba27d643d 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -420,6 +420,8 @@ func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName) + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + return &proto.SyncResponse{ WiretrusteeConfig: wtConfig, PeerConfig: pConfig, @@ -433,6 +435,7 @@ func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials RemotePeersIsEmpty: len(remotePeers) == 0, Routes: routesUpdate, DNSConfig: dnsUpdate, + FirewallRules: firewallRules, }, } } diff --git a/management/server/network.go b/management/server/network.go index c436a88b8..c26c1bbd3 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -23,11 +23,12 @@ const ( ) type NetworkMap struct { - Peers []*Peer - Network *Network - Routes []*route.Route - DNSConfig nbdns.Config - OfflinePeers []*Peer + Peers []*Peer + Network *Network + Routes []*route.Route + DNSConfig nbdns.Config + OfflinePeers []*Peer + FirewallRules []*FirewallRule } type Network struct { diff --git a/management/server/policy.go b/management/server/policy.go index 31f6bb655..5ae3e56fc 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -8,6 +8,7 @@ import ( "html/template" "strings" + "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" @@ -449,3 +450,17 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (e } return } + +func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(update)) + for i := range update { + result[i] = &proto.FirewallRule{ + PeerID: update[i].PeerID, + PeerIP: update[i].PeerIP, + Direction: update[i].Direction, + Port: update[i].Port, + Action: update[i].Action, + } + } + return result +}