diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 4e690ff1b..2f1df9b1a 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -146,6 +146,64 @@ jobs: - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) + test_client_on_docker: + name: "Client (Docker) / Unit" + needs: [build-cache] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + id: go-env + run: | + echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT + echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT + + - name: Cache Go modules + uses: actions/cache/restore@v4 + id: cache-restore + with: + path: | + ${{ steps.go-env.outputs.cache_dir }} + ${{ steps.go-env.outputs.modcache_dir }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Run tests in container + env: + HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }} + HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }} + run: | + CONTAINER_GOCACHE="/root/.cache/go-build" + CONTAINER_GOMODCACHE="/go/pkg/mod" + + docker run --rm \ + --cap-add=NET_ADMIN \ + --privileged \ + -v $PWD:/app \ + -w /app \ + -v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \ + -v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \ + -e CGO_ENABLED=1 \ + -e CI=true \ + -e GOARCH=${GOARCH_TARGET} \ + -e GOCACHE=${CONTAINER_GOCACHE} \ + -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ + golang:1.23-alpine \ + sh -c ' \ + apk update; apk add --no-cache \ + ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ + go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui) + ' + test_relay: name: "Relay / Unit" needs: [build-cache] @@ -179,13 +237,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -232,13 +283,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -286,13 +330,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -354,13 +391,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -449,13 +479,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -520,13 +543,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -541,99 +557,3 @@ jobs: go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -timeout 20m ./management/... - - test_client_on_docker: - name: "Client (Docker) / Unit" - needs: [ build-cache ] - runs-on: ubuntu-22.04 - steps: - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: "1.23.x" - cache: false - - - name: Checkout code - uses: actions/checkout@v4 - - - name: Get Go environment - run: | - echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV - echo "modcache=$(go env.GOMODCACHE)" >> $GITHUB_ENV - - - name: Cache Go modules - uses: actions/cache/restore@v4 - with: - path: | - ${{ env.cache }} - ${{ env.modcache }} - key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-gotest-cache- - - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install modules - run: go mod tidy - - - name: Check git status - run: git --no-pager diff --exit-code - - - name: Generate Shared Sock Test bin - run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - - - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager - - - name: Generate SystemOps Test bin (static via Alpine) - run: | - docker run --rm -v $PWD:/app -w /app \ - alpine:latest \ - sh -c " - apk add --no-cache go gcc musl-dev libpcap-dev dbus-dev && \ - adduser -D -u $(id -u) builder && \ - su builder -c '\ - cd /app && \ - CGO_ENABLED=1 GOOS=linux GOARCH=amd64 \ - go test -c -o /app/systemops-testing.bin \ - -tags netgo \ - -ldflags=\"-w -extldflags \\\"-static -ldbus-1 -lpcap\\\"\" \ - ./client/internal/routemanager/systemops \ - ' - " - - - name: Generate nftables Manager Test bin - run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... - - - name: Generate Engine Test bin - run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal - - - name: Generate Peer Test bin - run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/ - - - run: chmod +x *testing.bin - - - name: Run Shared Sock tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /ci/sharedsock-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Iface tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... - - - name: Run RouteManager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /ci/routemanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run SystemOps tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /ci/systemops-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run nftables Manager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /ci/nftablesmanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Engine tests in docker with file store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Engine tests in docker with sqlite store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Peer tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /ci/peer-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 diff --git a/client/Dockerfile b/client/Dockerfile index 35c1d04c2..16b2916c7 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,5 +1,6 @@ FROM alpine:3.21.3 -RUN apk add --no-cache ca-certificates iptables ip6tables +# iproute2: busybox doesn't display ip rules properly +RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] -COPY netbird /usr/local/bin/netbird \ No newline at end of file +COPY netbird /usr/local/bin/netbird diff --git a/client/cmd/login.go b/client/cmd/login.go index 549eef40e..84906a7a4 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -48,9 +48,6 @@ var loginCmd = &cobra.Command{ return err } - // update host's static platform and system information - system.UpdateStaticInfo() - // workaround to run without service if logFile == "console" { err = handleRebrand(cmd) @@ -58,6 +55,9 @@ var loginCmd = &cobra.Command{ return err } + // update host's static platform and system information + system.UpdateStaticInfo() + ic := internal.ConfigInput{ ManagementURL: managementURL, AdminURL: adminURL, diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 0dff3acc7..2ae983f6e 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" + "net/netip" "runtime" + "sync" log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/buffer" @@ -17,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) @@ -29,8 +32,10 @@ const ( ) type Forwarder struct { - logger *nblog.Logger - flowLogger nftypes.FlowLogger + logger *nblog.Logger + flowLogger nftypes.FlowLogger + // ruleIdMap is used to store the rule ID for a given connection + ruleIdMap sync.Map stack *stack.Stack endpoint *endpoint udpForwarder *udpForwarder @@ -167,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { } return addr.AsSlice() } + +func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { + key := buildKey(srcIP, dstIP, srcPort, dstPort) + f.ruleIdMap.LoadOrStore(key, ruleID) +} + +func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { + + if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return value.([]byte), true + } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { + return value.([]byte), true + } + + return nil, false +} + +func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { + if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return + } + f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort)) +} + +func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey { + return conntrack.ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index a21ec2c87..08d77ed05 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf } flowID := uuid.New() - f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) + f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) defer cancel() @@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf // TODO: support non-root conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") if err != nil { - f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) // This will make netstack reply on behalf of the original destination, that's ok for now return false } defer func() { if err := conn.Close(); err != nil { - f.logger.Debug("Failed to close ICMP socket: %v", err) + f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err) } }() @@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf payload := fullPacket.AsSlice() if _, err = conn.WriteTo(payload, dst); err != nil { - f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) return true } - f.logger.Trace("Forwarded ICMP packet %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) // For Echo Requests, send and handle response if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { - f.handleEchoResponse(icmpHdr, conn, id) - f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) + rxBytes := pkt.Size() + txBytes := f.handleEchoResponse(icmpHdr, conn, id) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing return true } -func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) { +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - f.logger.Error("Failed to set read deadline for ICMP response: %v", err) - return + f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err) + return 0 } response := make([]byte, f.endpoint.mtu) n, _, err := conn.ReadFrom(response) if err != nil { if !isTimeout(err) { - f.logger.Error("Failed to read ICMP response: %v", err) + f.logger.Error("forwarder: Failed to read ICMP response: %v", err) } - return + return 0 } ipHdr := make([]byte, header.IPv4MinimumSize) @@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon fullPacket = append(fullPacket, response[:n]...) if err := f.InjectIncomingPacket(fullPacket); err != nil { - f.logger.Error("Failed to inject ICMP response: %v", err) + f.logger.Error("forwarder: Failed to inject ICMP response: %v", err) - return + return 0 } - f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) + + return len(fullPacket) } // sendICMPEvent stores flow events for ICMP packets -func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) { - f.flowLogger.StoreEvent(nftypes.EventFields{ +func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) { + var rxPackets, txPackets uint64 + if rxBytes > 0 { + rxPackets = 1 + } + if txBytes > 0 { + txPackets = 1 + } + + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.ICMP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, ICMPType: icmpType, ICMPCode: icmpCode, - // TODO: get packets/bytes - }) + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, + } + + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId + } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) + } + + f.flowLogger.StoreEvent(fields) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 71cd457ef..04b3ae233 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -6,8 +6,10 @@ import ( "io" "net" "net/netip" + "sync" "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { flowID := uuid.New() - f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { - defer func() { - if err := inConn.Close(); err != nil { - f.logger.Debug("forwarder: inConn close error: %v", err) - } - if err := outConn.Close(); err != nil { - f.logger.Debug("forwarder: outConn close error: %v", err) - } - ep.Close() - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep) - }() - - // Create context for managing the proxy goroutines ctx, cancel := context.WithCancel(f.ctx) defer cancel() - errChan := make(chan error, 2) - go func() { - _, err := io.Copy(outConn, inConn) - errChan <- err - }() - - go func() { - _, err := io.Copy(inConn, outConn) - errChan <- err - }() - - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyTCP: copy error: %v", err) + <-ctx.Done() + // Close connections and endpoint. + if err := inConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: inConn close error: %v", err) + } + if err := outConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + + ep.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var ( + bytesFromInToOut int64 // bytes from client to server (tx for client) + bytesFromOutToIn int64 // bytes from server to client (rx for client) + errInToOut error + errOutToIn error + ) + + go func() { + bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) + cancel() + wg.Done() + }() + + go func() { + + bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) + cancel() + wg.Done() + }() + + wg.Wait() + + if errInToOut != nil { + if !isClosedError(errInToOut) { + f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut) } - f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id)) - return } + if errOutToIn != nil { + if !isClosedError(errOutToIn) { + f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn) + } + } + + var rxPackets, txPackets uint64 + if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { + // fields are flipped since this is the in conn + rxPackets = tcpStats.SegmentsSent.Value() + txPackets = tcpStats.SegmentsReceived.Value() + } + + f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } -func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.TCP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.SegmentsSent.Value() - fields.TxPackets = tcpStats.SegmentsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 7ce85e2b6..cb88aa59a 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { flowID := uuid.New() - f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - return } f.udpForwarder.conns[id] = pConn @@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { } func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { - defer func() { + + ctx, cancel := context.WithCancel(f.ctx) + defer cancel() + + go func() { + <-ctx.Done() + pConn.cancel() - if err := pConn.conn.Close(); err != nil { + if err := pConn.conn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) } - if err := pConn.outConn.Close(); err != nil { + if err := pConn.outConn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } ep.Close() - - f.udpForwarder.Lock() - delete(f.udpForwarder.conns, id) - f.udpForwarder.Unlock() - - f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep) }() - errChan := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) + var txBytes, rxBytes int64 + var outboundErr, inboundErr error + + // outbound->inbound: copy from pConn.conn to pConn.outConn go func() { - errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") + defer wg.Done() + txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") }() + // inbound->outbound: copy from pConn.outConn to pConn.conn go func() { - errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") + defer wg.Done() + rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") }() - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyUDP: copy error: %v", err) - } - f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id)) - return + wg.Wait() + + if outboundErr != nil && !isClosedError(outboundErr) { + f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr) } + if inboundErr != nil && !isClosedError(inboundErr) { + f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr) + } + + var rxPackets, txPackets uint64 + if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { + // fields are flipped since this is the in conn + rxPackets = udpStats.PacketsSent.Value() + txPackets = udpStats.PacketsReceived.Value() + } + + f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + + f.udpForwarder.Lock() + delete(f.udpForwarder.conns, id) + f.udpForwarder.Unlock() + + f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets) } // sendUDPEvent stores flow events for UDP connections -func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.UDP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.PacketsSent.Value() - fields.TxPackets = tcpStats.PacketsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) @@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration { return time.Since(lastSeen) } -func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { +// copy reads from src and writes to dst. +func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) { bufp := bufPool.Get().(*[]byte) defer bufPool.Put(bufp) buffer := *bufp + var totalBytes int64 = 0 for { if ctx.Err() != nil { - return ctx.Err() + return totalBytes, ctx.Err() } if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set read deadline: %w", err) + return totalBytes, fmt.Errorf("set read deadline: %w", err) } n, err := src.Read(buffer) @@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu if isTimeout(err) { continue } - return fmt.Errorf("read from %s: %w", direction, err) + return totalBytes, fmt.Errorf("read from %s: %w", direction, err) } - _, err = dst.Write(buffer[:n]) + nWritten, err := dst.Write(buffer[:n]) if err != nil { - return fmt.Errorf("write to %s: %w", direction, err) + return totalBytes, fmt.Errorf("write to %s: %w", direction, err) } + totalBytes += int64(nWritten) c.updateLastSeen() } } diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 53ee6c886..bd87879a5 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -198,12 +198,12 @@ func TestTracePacket(t *testing.T) { m.forwarder.Store(&forwarder.Forwarder{}) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) - dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -222,12 +222,12 @@ func TestTracePacket(t *testing.T) { m.nativeRouter.Store(false) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) - dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -245,7 +245,7 @@ func TestTracePacket(t *testing.T) { m.nativeRouter.Store(true) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -263,7 +263,7 @@ func TestTracePacket(t *testing.T) { m.routingEnabled.Store(false) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -425,8 +425,8 @@ func TestTracePacket(t *testing.T) { require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")), "100.10.0.100 should be recognized as a local IP") - require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")), - "172.17.0.2 should not be recognized as a local IP") + require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")), + "192.168.17.2 should not be recognized as a local IP") pb := tc.packetBuilder() diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index ccf0be225..11730dbb3 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -824,7 +824,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe proto, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { + ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + if !pass { m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", ruleID, pnum, srcIP, srcPort, dstIP, dstPort) @@ -850,8 +851,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe if fwd == nil { m.logger.Trace("failed to forward routed packet (forwarder not initialized)") } else { + fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID) + if err := fwd.InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject routed packet: %v", err) + fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort) } } diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 84b84483e..a83d7f1de 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -24,7 +24,6 @@ func init() { testCases = append(testCases, []testCase{ { name: "To more specific route without custom dialer via vpn", - destination: "10.10.0.2:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), diff --git a/client/internal/routemanager/systemops/systemops_linux_test.go b/client/internal/routemanager/systemops/systemops_linux_test.go index 8f12740d0..f0d7472dc 100644 --- a/client/internal/routemanager/systemops/systemops_linux_test.go +++ b/client/internal/routemanager/systemops/systemops_linux_test.go @@ -27,14 +27,12 @@ func init() { testCases = append(testCases, []testCase{ { name: "To more specific route without custom dialer via physical interface", - destination: "10.10.0.2:53", expectedInterface: expectedInternalInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), }, { name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.1:53", expectedInterface: expectedLoopbackInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), @@ -134,6 +132,16 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { _, dstIPNet, err := net.ParseCIDR(dstCIDR) require.NoError(t, err) + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + + route := &netlink.Route{ + Dst: dstIPNet, + Gw: gw, + LinkIndex: linkIndex, + } + // Handle existing routes with metric 0 var originalNexthop net.IP var originalLinkIndex int @@ -145,32 +153,24 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { } if originalNexthop != nil { + // remove original route err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - switch { - case err != nil && !errors.Is(err, syscall.ESRCH): - t.Logf("Failed to delete route: %v", err) - case err == nil: - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - default: - t.Logf("Failed to delete route: %v", err) - } + assert.NoError(t, err) + + // add new route + assert.NoError(t, netlink.RouteAdd(route)) + + t.Cleanup(func() { + // restore original route + assert.NoError(t, netlink.RouteDel(route)) + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + assert.NoError(t, err) + }) + + return } } - link, err := netlink.LinkByName(intf) - require.NoError(t, err) - linkIndex := link.Attrs().Index - - route := &netlink.Route{ - Dst: dstIPNet, - Gw: gw, - LinkIndex: linkIndex, - } err = netlink.RouteDel(route) if err != nil && !errors.Is(err, syscall.ESRCH) { t.Logf("Failed to delete route: %v", err) @@ -180,7 +180,6 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } - require.NoError(t, err) } func fetchOriginalGateway(family int) (net.IP, int, error) { @@ -190,7 +189,11 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } for _, route := range routes { - if route.Dst == nil && route.Priority == 0 { + ones := -1 + if route.Dst != nil { + ones, _ = route.Dst.Mask.Size() + } + if route.Dst == nil || ones == 0 && route.Priority == 0 { return route.Gw, route.LinkIndex, nil } } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index d88c1ab6b..ad37f611f 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -31,7 +31,6 @@ type PacketExpectation struct { type testCase struct { name string - destination string expectedInterface string dialer dialer expectedPacket PacketExpectation @@ -40,14 +39,12 @@ type testCase struct { var testCases = []testCase{ { name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), }, { name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), @@ -55,14 +52,12 @@ var testCases = []testCase{ { name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", expectedInterface: expectedInternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), }, { name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", expectedInterface: expectedInternalInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), @@ -70,14 +65,12 @@ var testCases = []testCase{ { name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), }, { name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), @@ -94,10 +87,11 @@ func TestRouting(t *testing.T) { t.Run(tc.name, func(t *testing.T) { setupTestEnv(t) - filter := createBPFFilter(tc.destination) + dst := fmt.Sprintf("%s:%d", tc.expectedPacket.DstIP, tc.expectedPacket.DstPort) + filter := createBPFFilter(dst) handle := startPacketCapture(t, tc.expectedInterface, filter) - sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + sendTestPacket(t, dst, tc.expectedPacket.SrcPort, tc.dialer) packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) packet, err := packetSource.NextPacket()