mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Use platform-native routing APIs for freeBSD, macOS and Windows
This commit is contained in:
@@ -3,11 +3,11 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
@@ -19,81 +19,32 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine not initialized")
|
||||
tracer, engine, err := s.getPacketTracer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager == nil {
|
||||
return nil, fmt.Errorf("firewall manager not initialized")
|
||||
srcAddr, err := s.parseAddress(req.GetSourceIp(), engine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid source IP address: %w", err)
|
||||
}
|
||||
|
||||
tracer, ok := fwManager.(packetTracer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("firewall manager does not support packet tracing")
|
||||
dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid destination IP address: %w", err)
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP(req.GetSourceIp())
|
||||
if req.GetSourceIp() == "self" {
|
||||
srcIP = engine.GetWgAddr()
|
||||
protocol, err := s.parseProtocol(req.GetProtocol())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
srcAddr, ok := netip.AddrFromSlice(srcIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
direction, err := s.parseDirection(req.GetDirection())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||
if req.GetDestinationIp() == "self" {
|
||||
dstIP = engine.GetWgAddr()
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(dstIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
}
|
||||
|
||||
if srcIP == nil || dstIP == nil {
|
||||
return nil, fmt.Errorf("invalid IP address")
|
||||
}
|
||||
|
||||
var tcpState *uspfilter.TCPState
|
||||
if flags := req.GetTcpFlags(); flags != nil {
|
||||
tcpState = &uspfilter.TCPState{
|
||||
SYN: flags.GetSyn(),
|
||||
ACK: flags.GetAck(),
|
||||
FIN: flags.GetFin(),
|
||||
RST: flags.GetRst(),
|
||||
PSH: flags.GetPsh(),
|
||||
URG: flags.GetUrg(),
|
||||
}
|
||||
}
|
||||
|
||||
var dir fw.RuleDirection
|
||||
switch req.GetDirection() {
|
||||
case "in":
|
||||
dir = fw.RuleDirectionIN
|
||||
case "out":
|
||||
dir = fw.RuleDirectionOUT
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid direction")
|
||||
}
|
||||
|
||||
var protocol fw.Protocol
|
||||
switch req.GetProtocol() {
|
||||
case "tcp":
|
||||
protocol = fw.ProtocolTCP
|
||||
case "udp":
|
||||
protocol = fw.ProtocolUDP
|
||||
case "icmp":
|
||||
protocol = fw.ProtocolICMP
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid protocolcol")
|
||||
}
|
||||
tcpState := s.parseTCPFlags(req.GetTcpFlags())
|
||||
|
||||
builder := &uspfilter.PacketBuilder{
|
||||
SrcIP: srcAddr,
|
||||
@@ -101,16 +52,96 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
Protocol: protocol,
|
||||
SrcPort: uint16(req.GetSourcePort()),
|
||||
DstPort: uint16(req.GetDestinationPort()),
|
||||
Direction: dir,
|
||||
Direction: direction,
|
||||
TCPState: tcpState,
|
||||
ICMPType: uint8(req.GetIcmpType()),
|
||||
ICMPCode: uint8(req.GetIcmpCode()),
|
||||
}
|
||||
|
||||
trace, err := tracer.TracePacketFromBuilder(builder)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("trace packet: %w", err)
|
||||
}
|
||||
|
||||
return s.buildTraceResponse(trace), nil
|
||||
}
|
||||
|
||||
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, nil, fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager == nil {
|
||||
return nil, nil, fmt.Errorf("firewall manager not initialized")
|
||||
}
|
||||
|
||||
tracer, ok := fwManager.(packetTracer)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("firewall manager does not support packet tracing")
|
||||
}
|
||||
|
||||
return tracer, engine, nil
|
||||
}
|
||||
|
||||
func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) {
|
||||
if addr == "self" {
|
||||
return engine.GetWgAddr(), nil
|
||||
}
|
||||
|
||||
a, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
|
||||
return a.Unmap(), nil
|
||||
}
|
||||
|
||||
func (s *Server) parseProtocol(protocol string) (fw.Protocol, error) {
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
return fw.ProtocolTCP, nil
|
||||
case "udp":
|
||||
return fw.ProtocolUDP, nil
|
||||
case "icmp":
|
||||
return fw.ProtocolICMP, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parseDirection(direction string) (fw.RuleDirection, error) {
|
||||
switch direction {
|
||||
case "in":
|
||||
return fw.RuleDirectionIN, nil
|
||||
case "out":
|
||||
return fw.RuleDirectionOUT, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid direction")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parseTCPFlags(flags *proto.TCPFlags) *uspfilter.TCPState {
|
||||
if flags == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &uspfilter.TCPState{
|
||||
SYN: flags.GetSyn(),
|
||||
ACK: flags.GetAck(),
|
||||
FIN: flags.GetFin(),
|
||||
RST: flags.GetRst(),
|
||||
PSH: flags.GetPsh(),
|
||||
URG: flags.GetUrg(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) buildTraceResponse(trace *uspfilter.PacketTrace) *proto.TracePacketResponse {
|
||||
resp := &proto.TracePacketResponse{}
|
||||
|
||||
for _, result := range trace.Results {
|
||||
@@ -119,10 +150,12 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
Message: result.Message,
|
||||
Allowed: result.Allowed,
|
||||
}
|
||||
|
||||
if result.ForwarderAction != nil {
|
||||
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
|
||||
stage.ForwardingDetails = &details
|
||||
}
|
||||
|
||||
resp.Stages = append(resp.Stages, stage)
|
||||
}
|
||||
|
||||
@@ -130,5 +163,5 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return resp
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user