mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Compare commits
69 Commits
feature/va
...
debug-0.33
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3efa7a282a | ||
|
|
40551099b3 | ||
|
|
9db1932664 | ||
|
|
1bbabf70b0 | ||
|
|
aa575d6f44 | ||
|
|
f66bbcc54c | ||
|
|
5dd6a08ea6 | ||
|
|
eb5d0569ae | ||
|
|
52ea2e84e9 | ||
|
|
78fab877c0 | ||
|
|
65a94f695f | ||
|
|
ec543f89fb | ||
|
|
a7d5c52203 | ||
|
|
582bb58714 | ||
|
|
121dfda915 | ||
|
|
a1c5287b7c | ||
|
|
12f442439a | ||
|
|
d9b691b8a5 | ||
|
|
4aee3c9e33 | ||
|
|
44e799c687 | ||
|
|
be78efbd42 | ||
|
|
6886691213 | ||
|
|
b48afd92fd | ||
|
|
39329e12a1 | ||
|
|
20a5afc359 | ||
|
|
6cb697eed6 | ||
|
|
e0bed2b0fb | ||
|
|
30f025e7dd | ||
|
|
b4d7605147 | ||
|
|
08b6e9d647 | ||
|
|
67ce14eaea | ||
|
|
669904cd06 | ||
|
|
4be826450b | ||
|
|
738387f2de | ||
|
|
baf0678ceb | ||
|
|
7fef8f6758 | ||
|
|
6829a64a2d | ||
|
|
cbf500024f | ||
|
|
509e184e10 | ||
|
|
3e88b7c56e | ||
|
|
b952d8693d | ||
|
|
5b46cc8e9c | ||
|
|
a9d06b883f | ||
|
|
5f06b202c3 | ||
|
|
0eb99c266a | ||
|
|
bac95ace18 | ||
|
|
9812de853b | ||
|
|
ad4f0a6fdf | ||
|
|
4c758c6e52 | ||
|
|
ec5095ba6b | ||
|
|
49a54624f8 | ||
|
|
729bcf2b01 | ||
|
|
a0cdb58303 | ||
|
|
39c99781cb | ||
|
|
01f24907c5 | ||
|
|
10480eb52f | ||
|
|
1e44c5b574 | ||
|
|
940f8b4547 | ||
|
|
46e37fa04c | ||
|
|
b9f205b2ce | ||
|
|
0fd874fa45 | ||
|
|
8016710d24 | ||
|
|
4e918e55ba | ||
|
|
869537c951 | ||
|
|
44f2ce666e | ||
|
|
563dca705c | ||
|
|
7bda385e1b | ||
|
|
30ebcf38c7 | ||
|
|
0106a95f7a |
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# These are supported funding model platforms
|
||||||
|
|
||||||
|
github: [netbirdio]
|
||||||
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -13,6 +13,7 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
@@ -49,7 +50,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
@@ -79,9 +80,6 @@ jobs:
|
|||||||
- name: check git status
|
- name: check git status
|
||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Generate Iface Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
|
|
||||||
|
|
||||||
- name: Generate Shared Sock Test bin
|
- name: Generate Shared Sock Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||||
|
|
||||||
@@ -98,7 +96,7 @@ jobs:
|
|||||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||||
|
|
||||||
- name: Generate Peer Test bin
|
- name: Generate Peer Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
|
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
||||||
|
|
||||||
- run: chmod +x *testing.bin
|
- run: chmod +x *testing.bin
|
||||||
|
|
||||||
@@ -106,7 +104,7 @@ jobs:
|
|||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Iface tests in docker
|
- name: Run Iface tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
||||||
|
|
||||||
- name: Run RouteManager tests in docker
|
- name: Run RouteManager tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif
|
ignore_words_list: erro,clienta,hastable,iif,groupd
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.16"
|
SIGN_PIPE_VER: "v0.0.17"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||||
|
|||||||
@@ -19,6 +19,10 @@
|
|||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
|
</a>
|
||||||
|
<br>
|
||||||
|
<a href="https://gurubase.io/g/netbird">
|
||||||
|
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
|
|||||||
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
||||||
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
||||||
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
||||||
|
|
||||||
|
"128.0.0.0", "8000::", // 2nd split subnet for default routes
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(wellKnown, addr.String()) {
|
if slices.Contains(wellKnown, addr.String()) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -13,10 +14,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type program struct {
|
type program struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
serv *grpc.Server
|
serv *grpc.Server
|
||||||
serverInstance *server.Server
|
serverInstance *server.Server
|
||||||
|
serverInstanceMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||||
|
|||||||
@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||||
|
|
||||||
|
p.serverInstanceMu.Lock()
|
||||||
p.serverInstance = serverInstance
|
p.serverInstance = serverInstance
|
||||||
|
p.serverInstanceMu.Unlock()
|
||||||
|
|
||||||
log.Printf("started daemon server: %v", split[1])
|
log.Printf("started daemon server: %v", split[1])
|
||||||
if err := p.serv.Serve(listen); err != nil {
|
if err := p.serv.Serve(listen); err != nil {
|
||||||
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *program) Stop(srv service.Service) error {
|
func (p *program) Stop(srv service.Service) error {
|
||||||
|
p.serverInstanceMu.Lock()
|
||||||
if p.serverInstance != nil {
|
if p.serverInstance != nil {
|
||||||
in := new(proto.DownRequest)
|
in := new(proto.DownRequest)
|
||||||
_, err := p.serverInstance.Down(p.ctx, in)
|
_, err := p.serverInstance.Down(p.ctx, in)
|
||||||
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
log.Errorf("failed to stop daemon: %v", err)
|
log.Errorf("failed to stop daemon: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
p.serverInstanceMu.Unlock()
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||||
statusEval := false
|
statusEval := false
|
||||||
ipEval := false
|
ipEval := false
|
||||||
nameEval := false
|
nameEval := true
|
||||||
|
|
||||||
if statusFilter != "" {
|
if statusFilter != "" {
|
||||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||||
@@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
|||||||
|
|
||||||
if len(prefixNamesFilter) > 0 {
|
if len(prefixNamesFilter) > 0 {
|
||||||
for prefixNameFilter := range prefixNamesFilterMap {
|
for prefixNameFilter := range prefixNamesFilterMap {
|
||||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||||
nameEval = true
|
nameEval = false
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
nameEval = false
|
||||||
}
|
}
|
||||||
|
|
||||||
return statusEval || ipEval || nameEval
|
return statusEval || ipEval || nameEval
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
@@ -11,10 +10,11 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -32,54 +33,65 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
var fm firewall.Manager
|
fm, err := createNativeFirewall(iface, stateManager)
|
||||||
var errFw error
|
|
||||||
|
|
||||||
|
if !iface.IsUserspaceBind() {
|
||||||
|
return fm, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
|
}
|
||||||
|
return createUserspaceFirewall(iface, fm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||||
|
fm, err := createFW(iface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = fm.Init(stateManager); err != nil {
|
||||||
|
return nil, fmt.Errorf("init firewall: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||||
switch check() {
|
switch check() {
|
||||||
case IPTABLES:
|
case IPTABLES:
|
||||||
log.Info("creating an iptables firewall manager")
|
log.Info("creating an iptables firewall manager")
|
||||||
fm, errFw = nbiptables.Create(context, iface)
|
return nbiptables.Create(iface)
|
||||||
if errFw != nil {
|
|
||||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
|
||||||
}
|
|
||||||
case NFTABLES:
|
case NFTABLES:
|
||||||
log.Info("creating an nftables firewall manager")
|
log.Info("creating an nftables firewall manager")
|
||||||
fm, errFw = nbnftables.Create(context, iface)
|
return nbnftables.Create(iface)
|
||||||
if errFw != nil {
|
|
||||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
errFw = fmt.Errorf("no firewall manager found")
|
|
||||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||||
|
return nil, errors.New("no firewall manager found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||||
|
var errUsp error
|
||||||
|
if fm != nil {
|
||||||
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||||
|
} else {
|
||||||
|
fm, errUsp = uspfilter.Create(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
if iface.IsUserspaceBind() {
|
if errUsp != nil {
|
||||||
var errUsp error
|
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||||
if errFw == nil {
|
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
|
||||||
} else {
|
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
|
||||||
}
|
|
||||||
if errUsp != nil {
|
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
|
||||||
return nil, errUsp
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := fm.AllowNetbird(); err != nil {
|
|
||||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
return fm, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if errFw != nil {
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
return nil, errFw
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fm, nil
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,6 +23,8 @@ const (
|
|||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type aclEntries map[string][][]string
|
||||||
|
|
||||||
type entry struct {
|
type entry struct {
|
||||||
spec []string
|
spec []string
|
||||||
position int
|
position int
|
||||||
@@ -32,9 +35,11 @@ type aclManager struct {
|
|||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
routingFwChainName string
|
||||||
|
|
||||||
entries map[string][][]string
|
entries aclEntries
|
||||||
optionalEntries map[string][]entry
|
optionalEntries map[string][]entry
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
|
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||||
@@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
|||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ipset.Init()
|
if err := ipset.Init(); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||||
|
m.stateManager = stateManager
|
||||||
|
|
||||||
m.seedInitialEntries()
|
m.seedInitialEntries()
|
||||||
m.seedInitialOptionalEntries()
|
m.seedInitialOptionalEntries()
|
||||||
|
|
||||||
err = m.cleanChains()
|
if err := m.cleanChains(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("clean chains: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.createDefaultChains()
|
if err := m.createDefaultChains(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("create default chains: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return m, nil
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
@@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
chain: chain,
|
chain: chain,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
return []firewall.Rule{rule}, nil
|
return []firewall.Rule{rule}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
|
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) Reset() error {
|
func (m *aclManager) Reset() error {
|
||||||
return m.cleanChains()
|
if err := m.cleanChains(); err != nil {
|
||||||
|
return fmt.Errorf("clean chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
// todo write less destructive cleanup mechanism
|
||||||
@@ -331,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
m.optionalEntries["FORWARD"] = []entry{
|
m.optionalEntries["FORWARD"] = []entry{
|
||||||
{
|
{
|
||||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
||||||
position: 2,
|
position: 2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m.optionalEntries["PREROUTING"] = []entry{
|
m.optionalEntries["PREROUTING"] = []entry{
|
||||||
{
|
{
|
||||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
|
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
||||||
position: 1,
|
position: 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
|||||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) updateState() {
|
||||||
|
if m.stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentState *ShutdownState
|
||||||
|
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||||
|
if existingState, ok := existing.(*ShutdownState); ok {
|
||||||
|
currentState = existingState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentState == nil {
|
||||||
|
currentState = &ShutdownState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState.Lock()
|
||||||
|
defer currentState.Unlock()
|
||||||
|
|
||||||
|
currentState.ACLEntries = m.entries
|
||||||
|
currentState.ACLIPsetStore = m.ipsetStore
|
||||||
|
|
||||||
|
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(
|
func filterRuleSpecs(
|
||||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||||
|
|||||||
@@ -8,10 +8,13 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
@@ -33,10 +36,10 @@ type iFaceMapper interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("init iptables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -44,20 +47,51 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
|||||||
ipv4Client: iptablesClient,
|
ipv4Client: iptablesClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.router, err = newRouter(context, iptablesClient, wgIface)
|
m.router, err = newRouter(iptablesClient, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to initialize route related chains: %s", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
|
state := &ShutdownState{
|
||||||
|
InterfaceState: &InterfaceState{
|
||||||
|
NameStr: m.wgIface.Name(),
|
||||||
|
WGAddress: m.wgIface.Address(),
|
||||||
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
stateManager.RegisterState(state)
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.init(stateManager); err != nil {
|
||||||
|
return fmt.Errorf("router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclMgr.init(stateManager); err != nil {
|
||||||
|
// TODO: cleanup router
|
||||||
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist early to ensure cleanup of chains
|
||||||
|
go func() {
|
||||||
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeerFiltering adds a rule to the firewall
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
@@ -133,20 +167,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
errAcl := m.aclMgr.Reset()
|
var merr *multierror.Error
|
||||||
if errAcl != nil {
|
|
||||||
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
if err := m.aclMgr.Reset(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||||
}
|
}
|
||||||
errMgr := m.router.Reset()
|
if err := m.router.Reset(); err != nil {
|
||||||
if errMgr != nil {
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
|
||||||
return errMgr
|
|
||||||
}
|
}
|
||||||
return errAcl
|
|
||||||
|
// attempt to delete state only if all other operations succeeded
|
||||||
|
if merr == nil {
|
||||||
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
err := manager.Reset(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||||
@@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
err := manager.Reset(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
err := manager.Reset(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -18,22 +17,25 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
)
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
const (
|
|
||||||
ipv4Nat = "netbird-rt-nat"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
const (
|
const (
|
||||||
tableFilter = "filter"
|
tableFilter = "filter"
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
|
chainPREROUTING = "PREROUTING"
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
chainRTFWD = "NETBIRD-RT-FWD"
|
chainRTFWD = "NETBIRD-RT-FWD"
|
||||||
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
|
jumpPre = "jump-pre"
|
||||||
|
jumpNat = "jump-nat"
|
||||||
matchSet = "--match-set"
|
matchSet = "--match-set"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,28 +50,31 @@ type routeFilteringRuleParams struct {
|
|||||||
SetName string
|
SetName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type routeRules map[string][]string
|
||||||
|
|
||||||
|
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
iptablesClient *iptables.IPTables
|
iptablesClient *iptables.IPTables
|
||||||
rules map[string][]string
|
rules routeRules
|
||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
|
ipsetCounter *ipsetCounter
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
|
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
r := &router{
|
r := &router{
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
r.createIpSet,
|
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||||
|
return struct{}{}, r.createIpSet(name, sources)
|
||||||
|
},
|
||||||
func(name string, _ struct{}) error {
|
func(name string, _ struct{}) error {
|
||||||
return r.deleteIpSet(name)
|
return r.deleteIpSet(name)
|
||||||
},
|
},
|
||||||
@@ -79,16 +84,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
|
|||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := r.cleanUpDefaultForwardRules()
|
return r, nil
|
||||||
if err != nil {
|
}
|
||||||
log.Errorf("cleanup routing rules: %s", err)
|
|
||||||
return nil, err
|
func (r *router) init(stateManager *statemanager.Manager) error {
|
||||||
|
r.stateManager = stateManager
|
||||||
|
|
||||||
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
}
|
}
|
||||||
err = r.createContainers()
|
|
||||||
if err != nil {
|
if err := r.createContainers(); err != nil {
|
||||||
log.Errorf("create containers for route: %s", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
return r, err
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
@@ -129,6 +141,8 @@ func (r *router) AddRouteFiltering(
|
|||||||
|
|
||||||
r.rules[string(ruleKey)] = rule
|
r.rules[string(ruleKey)] = rule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +166,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,18 +180,18 @@ func (r *router) findSetNameInRule(rule []string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||||
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
|
return fmt.Errorf("create set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, prefix := range sources {
|
for _, prefix := range sources {
|
||||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||||
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
|
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return struct{}{}, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string) error {
|
func (r *router) deleteIpSet(setName string) error {
|
||||||
@@ -206,6 +222,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,6 +241,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,8 +298,13 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
}
|
}
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,28 +319,31 @@ func (r *router) Reset() error {
|
|||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanUpDefaultForwardRules() error {
|
func (r *router) cleanUpDefaultForwardRules() error {
|
||||||
err := r.cleanJumpRules()
|
if err := r.cleanJumpRules(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("clean jump rules: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("flushing routing related tables")
|
log.Debug("flushing routing related tables")
|
||||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
for _, chainInfo := range []struct {
|
||||||
table := r.getTableForChain(chain)
|
chain string
|
||||||
|
table string
|
||||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
}{
|
||||||
|
{chainRTFWD, tableFilter},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTPRE, tableMangle},
|
||||||
|
} {
|
||||||
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed check chain %s, error: %v", chain, err)
|
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
return err
|
|
||||||
} else if ok {
|
} else if ok {
|
||||||
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
|
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -324,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
for _, chainInfo := range []struct {
|
||||||
if err := r.createAndSetupChain(chain); err != nil {
|
chain string
|
||||||
return fmt.Errorf("create chain %s: %w", chain, err)
|
table string
|
||||||
|
}{
|
||||||
|
{chainRTFWD, tableFilter},
|
||||||
|
{chainRTPRE, tableMangle},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
} {
|
||||||
|
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
||||||
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -334,6 +369,10 @@ func (r *router) createContainers() error {
|
|||||||
return fmt.Errorf("insert established rule: %w", err)
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
|
return fmt.Errorf("add static nat rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.addJumpRules(); err != nil {
|
if err := r.addJumpRules(); err != nil {
|
||||||
return fmt.Errorf("add jump rules: %w", err)
|
return fmt.Errorf("add jump rules: %w", err)
|
||||||
}
|
}
|
||||||
@@ -341,6 +380,32 @@ func (r *router) createContainers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) addPostroutingRules() error {
|
||||||
|
// First rule for outbound masquerade
|
||||||
|
rule1 := []string{
|
||||||
|
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
"!", "-o", "lo",
|
||||||
|
"-j", routingFinalNatJump,
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||||
|
return fmt.Errorf("add outbound masquerade rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules["static-nat-outbound"] = rule1
|
||||||
|
|
||||||
|
// Second rule for return traffic masquerade
|
||||||
|
rule2 := []string{
|
||||||
|
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-j", routingFinalNatJump,
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||||
|
return fmt.Errorf("add return masquerade rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules["static-nat-return"] = rule2
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createAndSetupChain(chain string) error {
|
func (r *router) createAndSetupChain(chain string) error {
|
||||||
table := r.getTableForChain(chain)
|
table := r.getTableForChain(chain)
|
||||||
|
|
||||||
@@ -352,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) getTableForChain(chain string) string {
|
func (r *router) getTableForChain(chain string) string {
|
||||||
if chain == chainRTNAT {
|
switch chain {
|
||||||
|
case chainRTNAT:
|
||||||
return tableNat
|
return tableNat
|
||||||
|
case chainRTPRE:
|
||||||
|
return tableMangle
|
||||||
|
default:
|
||||||
|
return tableFilter
|
||||||
}
|
}
|
||||||
return tableFilter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) insertEstablishedRule(chain string) error {
|
func (r *router) insertEstablishedRule(chain string) error {
|
||||||
@@ -373,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) addJumpRules() error {
|
func (r *router) addJumpRules() error {
|
||||||
rule := []string{"-j", chainRTNAT}
|
// Jump to NAT chain
|
||||||
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return err
|
return fmt.Errorf("add nat jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[ipv4Nat] = rule
|
r.rules[jumpNat] = natRule
|
||||||
|
|
||||||
|
// Jump to prerouting chain
|
||||||
|
preRule := []string{"-j", chainRTPRE}
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||||
|
return fmt.Errorf("add prerouting jump rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpPre] = preRule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanJumpRules() error {
|
func (r *router) cleanJumpRules() error {
|
||||||
rule, found := r.rules[ipv4Nat]
|
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
||||||
if found {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
table := tableNat
|
||||||
if err != nil {
|
chain := chainPOSTROUTING
|
||||||
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
|
if ruleKey == jumpPre {
|
||||||
|
table = tableMangle
|
||||||
|
chain = chainPREROUTING
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||||
|
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
|
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
|
if pair.Inverse {
|
||||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := []string{"-i", r.wgIface.Name()}
|
||||||
|
if pair.Inverse {
|
||||||
|
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule,
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", pair.Source.String(),
|
||||||
|
"-d", pair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||||
|
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = rule
|
r.rules[ruleKey] = rule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -419,26 +518,41 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||||
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("nat rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
func (r *router) updateState() {
|
||||||
intdir := "-i"
|
if r.stateManager == nil {
|
||||||
lointdir := "-o"
|
return
|
||||||
if inverse {
|
}
|
||||||
intdir = "-o"
|
|
||||||
lointdir = "-i"
|
var currentState *ShutdownState
|
||||||
|
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||||
|
if existingState, ok := existing.(*ShutdownState); ok {
|
||||||
|
currentState = existingState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentState == nil {
|
||||||
|
currentState = &ShutdownState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState.Lock()
|
||||||
|
defer currentState.Unlock()
|
||||||
|
|
||||||
|
currentState.RouteRules = r.rules
|
||||||
|
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||||
|
|
||||||
|
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
}
|
}
|
||||||
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
|
|||||||
@@ -3,18 +3,18 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
func isIptablesSupported() bool {
|
||||||
@@ -30,18 +30,29 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = manager.Reset()
|
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
// Now 5 rules:
|
||||||
|
// 1. established rule in forward chain
|
||||||
|
// 2. jump rule to NAT chain
|
||||||
|
// 3. jump rule to PRE chain
|
||||||
|
// 4. static outbound masquerade rule
|
||||||
|
// 5. static return masquerade rule
|
||||||
|
require.Len(t, manager.rules, 5, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting jump rule should exist")
|
||||||
|
|
||||||
|
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||||
|
require.True(t, exists, "prerouting jump rule should exist")
|
||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
@@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
err = manager.AddNatRule(pair)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "adding NAT rule should not return error")
|
||||||
|
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManager_AddNatRule(t *testing.T) {
|
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
if !isIptablesSupported() {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
@@ -74,56 +78,71 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to reset iptables manager: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = manager.AddNatRule(testCase.InputPair)
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
require.NoError(t, err, "marking rule should be inserted")
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
markingRule := []string{
|
||||||
|
"-i", ifaceMock.Name(),
|
||||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
"-m", "conntrack",
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
"--ctstate", "NEW",
|
||||||
if testCase.InputPair.Masquerade {
|
"-s", testCase.InputPair.Source.String(),
|
||||||
require.True(t, exists, "nat rule should be created")
|
"-d", testCase.InputPair.Destination.String(),
|
||||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
"-j", "MARK", "--set-mark",
|
||||||
require.True(t, foundNat, "nat rule should exist in the map")
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
|
||||||
} else {
|
|
||||||
require.False(t, exists, "nat rule should not be created")
|
|
||||||
_, foundNat := manager.rules[natRuleKey]
|
|
||||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
require.True(t, exists, "income nat rule should be created")
|
require.True(t, exists, "marking rule should be created")
|
||||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
foundRule, found := manager.rules[natRuleKey]
|
||||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
require.True(t, found, "marking rule should exist in the map")
|
||||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
|
||||||
} else {
|
} else {
|
||||||
require.False(t, exists, "nat rule should not be created")
|
require.False(t, exists, "marking rule should not be created")
|
||||||
_, foundNat := manager.rules[inNatRuleKey]
|
_, found := manager.rules[natRuleKey]
|
||||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
require.False(t, found, "marking rule should not exist in the map")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check inverse rule
|
||||||
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
|
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||||
|
inverseMarkingRule := []string{
|
||||||
|
"!", "-i", ifaceMock.Name(),
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", inversePair.Source.String(),
|
||||||
|
"-d", inversePair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark",
|
||||||
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
if testCase.InputPair.Masquerade {
|
||||||
|
require.True(t, exists, "inverse marking rule should be created")
|
||||||
|
foundRule, found := manager.rules[inverseRuleKey]
|
||||||
|
require.True(t, found, "inverse marking rule should exist in the map")
|
||||||
|
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
|
||||||
|
} else {
|
||||||
|
require.False(t, exists, "inverse marking rule should not be created")
|
||||||
|
_, found := manager.rules[inverseRuleKey]
|
||||||
|
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
if !isIptablesSupported() {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
@@ -132,45 +151,56 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
|
||||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
require.NoError(t, manager.init(nil))
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = manager.Reset()
|
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "should add NAT rule without error")
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.RemoveNatRule(testCase.InputPair)
|
err = manager.RemoveNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
markingRule := []string{
|
||||||
require.False(t, exists, "nat rule should not exist")
|
"-i", ifaceMock.Name(),
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", testCase.InputPair.Source.String(),
|
||||||
|
"-d", testCase.InputPair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark",
|
||||||
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
require.False(t, exists, "marking rule should not exist")
|
||||||
|
|
||||||
_, found := manager.rules[natRuleKey]
|
_, found := manager.rules[natRuleKey]
|
||||||
require.False(t, found, "nat rule should exist in the manager map")
|
require.False(t, found, "marking rule should not exist in the manager map")
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
// Check inverse rule removal
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
require.False(t, exists, "income nat rule should not exist")
|
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||||
|
inverseMarkingRule := []string{
|
||||||
|
"!", "-i", ifaceMock.Name(),
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", inversePair.Source.String(),
|
||||||
|
"-d", inversePair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark",
|
||||||
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
}
|
||||||
|
|
||||||
_, found = manager.rules[inNatRuleKey]
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||||
require.False(t, found, "income nat rule should exist in the manager map")
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
require.False(t, exists, "inverse marking rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[inverseRuleKey]
|
||||||
|
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -183,8 +213,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "Failed to create iptables client")
|
require.NoError(t, err, "Failed to create iptables client")
|
||||||
|
|
||||||
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
|
r, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "Failed to create router manager")
|
require.NoError(t, err, "Failed to create router manager")
|
||||||
|
require.NoError(t, r.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := r.Reset()
|
err := r.Reset()
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
type ipList struct {
|
type ipList struct {
|
||||||
ips map[string]struct{}
|
ips map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIpList(ip string) ipList {
|
func newIpList(ip string) *ipList {
|
||||||
ips := make(map[string]struct{})
|
ips := make(map[string]struct{})
|
||||||
ips[ip] = struct{}{}
|
ips[ip] = struct{}{}
|
||||||
|
|
||||||
return ipList{
|
return &ipList{
|
||||||
ips: ips,
|
ips: ips,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
|
|||||||
s.ips[ip] = struct{}{}
|
s.ips[ip] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler
|
||||||
|
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
IPs map[string]struct{} `json:"ips"`
|
||||||
|
}{
|
||||||
|
IPs: s.ips,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler
|
||||||
|
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||||
|
temp := struct {
|
||||||
|
IPs map[string]struct{} `json:"ips"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.ips = temp.IPs
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type ipsetStore struct {
|
type ipsetStore struct {
|
||||||
ipsets map[string]ipList // ipsetName -> ruleset
|
ipsets map[string]*ipList
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIpsetStore() *ipsetStore {
|
func newIpsetStore() *ipsetStore {
|
||||||
return &ipsetStore{
|
return &ipsetStore{
|
||||||
ipsets: make(map[string]ipList),
|
ipsets: make(map[string]*ipList),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||||
r, ok := s.ipsets[ipsetName]
|
r, ok := s.ipsets[ipsetName]
|
||||||
return r, ok
|
return r, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||||
s.ipsets[ipsetName] = list
|
s.ipsets[ipsetName] = list
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||||
s.ipsets[ipsetName] = ipList{}
|
|
||||||
delete(s.ipsets, ipsetName)
|
delete(s.ipsets, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
|
|||||||
}
|
}
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler
|
||||||
|
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
IPSets map[string]*ipList `json:"ipsets"`
|
||||||
|
}{
|
||||||
|
IPSets: s.ipsets,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler
|
||||||
|
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||||
|
temp := struct {
|
||||||
|
IPSets map[string]*ipList `json:"ipsets"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.ipsets = temp.IPSets
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
70
client/firewall/iptables/state_linux.go
Normal file
70
client/firewall/iptables/state_linux.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InterfaceState struct {
|
||||||
|
NameStr string `json:"name"`
|
||||||
|
WGAddress iface.WGAddress `json:"wg_address"`
|
||||||
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Name() string {
|
||||||
|
return i.NameStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Address() device.WGAddress {
|
||||||
|
return i.WGAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||||
|
return i.UserspaceBind
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShutdownState struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
|
|
||||||
|
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||||
|
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||||
|
|
||||||
|
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||||
|
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "iptables_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
ipt, err := Create(s.InterfaceState)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create iptables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.RouteRules != nil {
|
||||||
|
ipt.router.rules = s.RouteRules
|
||||||
|
}
|
||||||
|
if s.RouteIPsetCounter != nil {
|
||||||
|
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.ACLEntries != nil {
|
||||||
|
ipt.aclMgr.entries = s.ACLEntries
|
||||||
|
}
|
||||||
|
if s.ACLIPsetStore != nil {
|
||||||
|
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipt.Reset(nil); err != nil {
|
||||||
|
return fmt.Errorf("reset iptables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -10,11 +10,14 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ForwardingFormatPrefix = "netbird-fwd-"
|
ForwardingFormatPrefix = "netbird-fwd-"
|
||||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||||
|
PreroutingFormat = "netbird-prerouting-%s-%t"
|
||||||
NatFormat = "netbird-nat-%s-%t"
|
NatFormat = "netbird-nat-%s-%t"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,6 +55,8 @@ const (
|
|||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
// Netbird client for ACL and routing functionality
|
// Netbird client for ACL and routing functionality
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
|
Init(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
AllowNetbird() error
|
AllowNetbird() error
|
||||||
|
|
||||||
@@ -91,7 +96,7 @@ type Manager interface {
|
|||||||
SetLegacyManagement(legacy bool) error
|
SetLegacyManagement(legacy bool) error
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,13 +55,6 @@ type AclManager struct {
|
|||||||
rules map[string]*Rule
|
rules map[string]*Rule
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
|
||||||
type iFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||||
// sConn is used for creating sets and adding/removing elements from them
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
// it's differ then rConn (which does create new conn for each flush operation)
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
@@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
|||||||
// overloads netlink with high amount of rules ( > 10000)
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("create nf conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &AclManager{
|
return &AclManager{
|
||||||
rConn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
sConn: sConn,
|
sConn: sConn,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
@@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
|||||||
|
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
rules: make(map[string]*Rule),
|
rules: make(map[string]*Rule),
|
||||||
}
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
err = m.createDefaultChains()
|
func (m *AclManager) init(workTable *nftables.Table) error {
|
||||||
if err != nil {
|
m.workTable = workTable
|
||||||
return nil, err
|
return m.createDefaultChains()
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
@@ -530,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
|||||||
},
|
},
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
},
|
},
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: expr.MetaKeyMARK,
|
Key: expr.MetaKeyMARK,
|
||||||
@@ -553,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictJump,
|
Kind: expr.VerdictJump,
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -24,6 +26,13 @@ const (
|
|||||||
chainNameInput = "INPUT"
|
chainNameInput = "INPUT"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
|
type iFaceMapper interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
@@ -35,30 +44,70 @@ type Manager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
rConn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
workTable, err := m.createWorkTable()
|
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.router, err = newRouter(context, workTable, wgIface)
|
var err error
|
||||||
|
m.router, err = newRouter(workTable, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init nftables firewall manager
|
||||||
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
|
workTable, err := m.createWorkTable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create work table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.init(workTable); err != nil {
|
||||||
|
return fmt.Errorf("router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclManager.init(workTable); err != nil {
|
||||||
|
// TODO: cleanup router
|
||||||
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
|
// We only need to record minimal interface state for potential recreation.
|
||||||
|
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||||
|
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||||
|
// cleanup using Reset() without needing to store specific rules.
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
|
InterfaceState: &InterfaceState{
|
||||||
|
NameStr: m.wgIface.Name(),
|
||||||
|
WGAddress: m.wgIface.Address(),
|
||||||
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist early
|
||||||
|
go func() {
|
||||||
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
@@ -150,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
var chain *nftables.Chain
|
var chain *nftables.Chain
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||||
chain = c
|
chain = c
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -183,68 +232,84 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
oldLegacy := m.router.legacyManagement
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
|
}
|
||||||
|
|
||||||
if oldLegacy != isLegacy {
|
// Reset firewall to the default state
|
||||||
m.router.legacyManagement = isLegacy
|
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||||
log.Debugf("Set legacy management to %v", isLegacy)
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := m.resetNetbirdInputRules(); err != nil {
|
||||||
|
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
|
if err := m.router.Reset(); err != nil {
|
||||||
if !isLegacy && oldLegacy {
|
return fmt.Errorf("reset router: %v", err)
|
||||||
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
|
}
|
||||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Legacy routing rules removed")
|
if err := m.cleanupNetbirdTables(); err != nil {
|
||||||
|
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
return fmt.Errorf("delete state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
func (m *Manager) resetNetbirdInputRules() error {
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChains()
|
chains, err := m.rConn.ListChains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
return fmt.Errorf("list chains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.deleteNetbirdInputRules(chains)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
// delete Netbird allow input traffic rule if it exists
|
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
|
||||||
rules, err := m.rConn.GetRules(c.Table, c)
|
rules, err := m.rConn.GetRules(c.Table, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, r := range rules {
|
|
||||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
m.deleteMatchingRules(rules)
|
||||||
if err := m.rConn.DelRule(r); err != nil {
|
}
|
||||||
log.Errorf("delete rule: %v", err)
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||||
|
for _, r := range rules {
|
||||||
|
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||||
|
if err := m.rConn.DelRule(r); err != nil {
|
||||||
|
log.Errorf("delete rule: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.router.Reset(); err != nil {
|
func (m *Manager) cleanupNetbirdTables() error {
|
||||||
return fmt.Errorf("reset forward rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := m.rConn.ListTables()
|
tables, err := m.rConn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
return fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == tableNameNetbird {
|
if t.Name == tableNameNetbird {
|
||||||
m.rConn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return m.rConn.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
@@ -286,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(m.wgIface.Name()),
|
Data: ifname(m.wgIface.Name()),
|
||||||
},
|
},
|
||||||
&expr.Verdict{},
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
UserData: []byte(allowNetbirdInputRuleID),
|
UserData: []byte(allowNetbirdInputRuleID),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
|||||||
func TestNftablesManager(t *testing.T) {
|
func TestNftablesManager(t *testing.T) {
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
@@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// established rule remains
|
// established rule remains
|
||||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package nftables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -22,6 +21,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -40,8 +40,6 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
filterTable *nftables.Table
|
filterTable *nftables.Table
|
||||||
@@ -54,12 +52,8 @@ type router struct {
|
|||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
|
|
||||||
r := &router{
|
r := &router{
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
conn: &nftables.Conn{},
|
||||||
workTable: workTable,
|
workTable: workTable,
|
||||||
chains: make(map[string]*nftables.Chain),
|
chains: make(map[string]*nftables.Chain),
|
||||||
@@ -78,20 +72,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
|
|||||||
if errors.Is(err, errFilterTableNotFound) {
|
if errors.Is(err, errFilterTableNotFound) {
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
log.Warnf("table 'filter' not found for forward rules")
|
||||||
} else {
|
} else {
|
||||||
return nil, err
|
return nil, fmt.Errorf("load filter table: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.removeAcceptForwardRules()
|
return r, nil
|
||||||
if err != nil {
|
}
|
||||||
|
|
||||||
|
func (r *router) init(workTable *nftables.Table) error {
|
||||||
|
r.workTable = workTable
|
||||||
|
|
||||||
|
if err := r.removeAcceptForwardRules(); err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.createContainers()
|
if err := r.createContainers(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("create containers: %w", err)
|
||||||
log.Errorf("failed to create containers for route: %s", err)
|
|
||||||
}
|
}
|
||||||
return r, err
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset cleans existing nftables default forward rules from the system
|
// Reset cleans existing nftables default forward rules from the system
|
||||||
@@ -126,7 +125,6 @@ func (r *router) createContainers() error {
|
|||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
|
|
||||||
prio := *nftables.ChainPriorityNATSource - 1
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
|
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingNat,
|
Name: chainNameRoutingNat,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
@@ -135,6 +133,21 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Chain is created by acl manager
|
||||||
|
// TODO: move creation to a common place
|
||||||
|
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||||
|
Name: chainNamePrerouting,
|
||||||
|
Table: r.workTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the single NAT rule that matches on mark
|
||||||
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
|
return fmt.Errorf("add single nat rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.acceptForwardRules(); err != nil {
|
if err := r.acceptForwardRules(); err != nil {
|
||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
@@ -424,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||||
|
|
||||||
dir := expr.MetaKeyIIFNAME
|
op := expr.CmpOpEq
|
||||||
notDir := expr.MetaKeyOIFNAME
|
|
||||||
if pair.Inverse {
|
if pair.Inverse {
|
||||||
dir = expr.MetaKeyOIFNAME
|
op = expr.CmpOpNeq
|
||||||
notDir = expr.MetaKeyIIFNAME
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lo := ifname("lo")
|
|
||||||
intf := ifname(r.wgIface.Name())
|
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
&expr.Meta{
|
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||||
Key: dir,
|
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Bitwise{
|
||||||
Op: expr.CmpOpEq,
|
SourceRegister: 1,
|
||||||
Register: 1,
|
DestRegister: 1,
|
||||||
Data: intf,
|
Len: 4,
|
||||||
},
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
|
||||||
&expr.Meta{
|
|
||||||
Key: notDir,
|
|
||||||
Register: 1,
|
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpNeq,
|
Op: expr.CmpOpNeq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: lo,
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
|
||||||
|
// interface matching
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: op,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
exprs = append(exprs, destExp...)
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
|
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||||
|
if pair.Inverse {
|
||||||
|
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||||
|
}
|
||||||
|
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Counter{}, &expr.Masq{},
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
SourceRegister: true,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if _, exists := r.rules[ruleKey]; exists {
|
if _, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameRoutingNat],
|
Chain: r.chains[chainNamePrerouting],
|
||||||
Exprs: exprs,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPostroutingRules adds the masquerade rules
|
||||||
|
func (r *router) addPostroutingRules() error {
|
||||||
|
// First masquerade rule for traffic coming in from WireGuard interface
|
||||||
|
exprs := []expr.Any{
|
||||||
|
// Match on the first fwmark
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
},
|
||||||
|
|
||||||
|
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname("lo"),
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Masq{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: exprs,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Second masquerade rule for traffic going out through WireGuard interface
|
||||||
|
exprs2 := []expr.Any{
|
||||||
|
// Match on the second fwmark
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
},
|
||||||
|
|
||||||
|
// Match WireGuard interface
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Masq{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: exprs2,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,7 +656,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
}
|
}
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
@@ -722,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
@@ -748,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
|
|
||||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
err := r.conn.DelRule(rule)
|
err := r.conn.DelRule(rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("nftables: nat rule %s not found", ruleKey)
|
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -11,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -33,99 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
t.Skip("nftables not supported on this OS")
|
t.Skip("nftables not supported on this OS")
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := createWorkTable()
|
|
||||||
require.NoError(t, err, "Failed to create work table")
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
// need fw manager to init both acl mgr and router for all chains to be present
|
||||||
require.NoError(t, err, "failed to create router")
|
manager, err := Create(ifaceMock)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer func(manager *router) {
|
rtr := manager.router
|
||||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
}(manager)
|
|
||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.AddNatRule(testCase.InputPair)
|
|
||||||
require.NoError(t, err, "pair should be inserted")
|
require.NoError(t, err, "pair should be inserted")
|
||||||
|
|
||||||
defer func(manager *router, pair firewall.RouterPair) {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
|
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
|
||||||
}(manager, testCase.InputPair)
|
})
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
// Build expected expressions for connection tracking
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
conntrackExprs := []expr.Any{
|
||||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
&expr.Ct{
|
||||||
testingExpression = append(testingExpression,
|
Key: expr.CtKeySTATE,
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build interface matching expression
|
||||||
|
ifaceExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(ifaceMock.Name()),
|
Data: ifname(ifaceMock.Name()),
|
||||||
},
|
},
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
|
||||||
found := 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
}
|
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
// Build CIDR matching expressions
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
|
||||||
testingExpression = append(testingExpression,
|
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(ifaceMock.Name()),
|
|
||||||
},
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
// Combine all expressions in the correct order
|
||||||
|
// nolint:gocritic
|
||||||
|
testingExpression := append(conntrackExprs, ifaceExprs...)
|
||||||
|
testingExpression = append(testingExpression, sourceExp...)
|
||||||
|
testingExpression = append(testingExpression, destExp...)
|
||||||
|
|
||||||
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range manager.chains {
|
for _, chain := range rtr.chains {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
if chain.Name == chainNamePrerouting {
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
for _, rule := range rules {
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
for _, rule := range rules {
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
found = 1
|
// Compare expressions up to the mark setting expressions
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -135,67 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Skip("nftables not supported on this OS")
|
t.Skip("nftables not supported on this OS")
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := createWorkTable()
|
|
||||||
require.NoError(t, err, "Failed to create work table")
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
for _, testCase := range test.RemoveRuleTestCases {
|
for _, testCase := range test.RemoveRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, err, "failed to create router")
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
defer func(manager *router) {
|
|
||||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
|
||||||
}(manager)
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
|
||||||
|
|
||||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRoutingNat],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(natRuleKey),
|
|
||||||
})
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
|
rtr := manager.router
|
||||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
// First add the NAT rule using the router's method
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "should add NAT rule")
|
||||||
|
|
||||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
// Verify the rule was added
|
||||||
Table: manager.workTable,
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
Chain: manager.chains[chainNameRoutingNat],
|
found := false
|
||||||
Exprs: natExp,
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||||
UserData: []byte(inNatRuleKey),
|
require.NoError(t, err, "should list rules")
|
||||||
})
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
err = nftablesTestingClient.Flush()
|
found = true
|
||||||
require.NoError(t, err, "shouldn't return error")
|
break
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.RemoveNatRule(testCase.InputPair)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
|
||||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
require.True(t, found, "NAT rule should exist before removal")
|
||||||
|
|
||||||
|
// Now remove the rule
|
||||||
|
err = rtr.RemoveNatRule(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "shouldn't return error when removing rule")
|
||||||
|
|
||||||
|
// Verify the rule was removed
|
||||||
|
found = false
|
||||||
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||||
|
require.NoError(t, err, "should list rules after removal")
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.False(t, found, "NAT rule should not exist after removal")
|
||||||
|
|
||||||
|
// Verify the static postrouting rules still exist
|
||||||
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
|
||||||
|
require.NoError(t, err, "should list postrouting rules")
|
||||||
|
foundCounter := false
|
||||||
|
for _, rule := range rules {
|
||||||
|
for _, e := range rule.Exprs {
|
||||||
|
if _, ok := e.(*expr.Counter); ok {
|
||||||
|
foundCounter = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundCounter {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, foundCounter, "static postrouting rule should remain")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -210,8 +197,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
r, err := newRouter(workTable, ifaceMock)
|
||||||
require.NoError(t, err, "Failed to create router")
|
require.NoError(t, err, "Failed to create router")
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func(r *router) {
|
defer func(r *router) {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||||
@@ -376,8 +364,9 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
r, err := newRouter(workTable, ifaceMock)
|
||||||
require.NoError(t, err, "Failed to create router")
|
require.NoError(t, err, "Failed to create router")
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||||
|
|||||||
1
client/firewall/nftables/state.go
Normal file
1
client/firewall/nftables/state.go
Normal file
@@ -0,0 +1 @@
|
|||||||
|
package nftables
|
||||||
47
client/firewall/nftables/state_linux.go
Normal file
47
client/firewall/nftables/state_linux.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InterfaceState struct {
|
||||||
|
NameStr string `json:"name"`
|
||||||
|
WGAddress iface.WGAddress `json:"wg_address"`
|
||||||
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Name() string {
|
||||||
|
return i.NameStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Address() device.WGAddress {
|
||||||
|
return i.WGAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||||
|
return i.UserspaceBind
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShutdownState struct {
|
||||||
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "nftables_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
nft, err := Create(s.InterfaceState)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create nftables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := nft.Reset(nil); err != nil {
|
||||||
|
return fmt.Errorf("reset nftables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
|
|||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Reset()
|
return m.nativeFirewall.Reset(stateManager)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type action string
|
type action string
|
||||||
@@ -17,7 +19,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
@@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Init(*statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return false
|
return false
|
||||||
@@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return nil, errRouteNotSupported
|
return nil, errRouteNotSupported
|
||||||
}
|
}
|
||||||
@@ -232,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||||
func (m *Manager) SetLegacyManagement(_ bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
return nil
|
if m.nativeFirewall == nil {
|
||||||
|
return errRouteNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.Reset()
|
err = m.Reset(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = m.Reset(); err != nil {
|
if err = m.Reset(nil); err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|||||||
@@ -1,142 +0,0 @@
|
|||||||
package bind
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/pion/stun/v2"
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
|
||||||
)
|
|
||||||
|
|
||||||
type receiverCreator struct {
|
|
||||||
iceBind *ICEBind
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
|
||||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ICEBind struct {
|
|
||||||
*wgConn.StdNetBind
|
|
||||||
|
|
||||||
muUDPMux sync.Mutex
|
|
||||||
|
|
||||||
transportNet transport.Net
|
|
||||||
udpMux *UniversalUDPMuxDefault
|
|
||||||
|
|
||||||
filterFn FilterFn
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
|
||||||
ib := &ICEBind{
|
|
||||||
transportNet: transportNet,
|
|
||||||
filterFn: filterFn,
|
|
||||||
}
|
|
||||||
|
|
||||||
rc := receiverCreator{
|
|
||||||
ib,
|
|
||||||
}
|
|
||||||
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
|
||||||
return ib
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
|
||||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
|
||||||
s.muUDPMux.Lock()
|
|
||||||
defer s.muUDPMux.Unlock()
|
|
||||||
if s.udpMux == nil {
|
|
||||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.udpMux, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
|
||||||
s.muUDPMux.Lock()
|
|
||||||
defer s.muUDPMux.Unlock()
|
|
||||||
|
|
||||||
s.udpMux = NewUniversalUDPMuxDefault(
|
|
||||||
UniversalUDPMuxParams{
|
|
||||||
UDPConn: conn,
|
|
||||||
Net: s.transportNet,
|
|
||||||
FilterFn: s.filterFn,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
|
||||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
||||||
defer ipv4MsgsPool.Put(msgs)
|
|
||||||
for i := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
|
||||||
}
|
|
||||||
var numMsgs int
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg := &(*msgs)[0]
|
|
||||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
numMsgs = 1
|
|
||||||
}
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
|
||||||
msg := &(*msgs)[i]
|
|
||||||
|
|
||||||
// todo: handle err
|
|
||||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
|
||||||
if ok {
|
|
||||||
sizes[i] = 0
|
|
||||||
} else {
|
|
||||||
sizes[i] = msg.N
|
|
||||||
}
|
|
||||||
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
||||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
||||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
|
||||||
return numMsgs, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
|
||||||
for i := range buffers {
|
|
||||||
if !stun.IsMessage(buffers[i]) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := s.parseSTUNMessage(buffers[i][:n])
|
|
||||||
if err != nil {
|
|
||||||
buffers[i] = []byte{}
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
|
|
||||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
|
||||||
if muxErr != nil {
|
|
||||||
log.Warnf("failed to handle STUN packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
buffers[i] = []byte{}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
|
||||||
msg := &stun.Message{
|
|
||||||
Raw: raw,
|
|
||||||
}
|
|
||||||
if err := msg.Decode(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
5
client/iface/bind/endpoint.go
Normal file
5
client/iface/bind/endpoint.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
type Endpoint = wgConn.StdNetEndpoint
|
||||||
303
client/iface/bind/ice_bind.go
Normal file
303
client/iface/bind/ice_bind.go
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/stun/v2"
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RecvMessage struct {
|
||||||
|
Endpoint *Endpoint
|
||||||
|
Buffer []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type receiverCreator struct {
|
||||||
|
iceBind *ICEBind
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
|
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICEBind is a bind implementation with two main features:
|
||||||
|
// 1. filter out STUN messages and handle them
|
||||||
|
// 2. forward the received packets to the WireGuard interface from the relayed connection
|
||||||
|
//
|
||||||
|
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
|
||||||
|
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
|
||||||
|
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||||
|
type ICEBind struct {
|
||||||
|
*wgConn.StdNetBind
|
||||||
|
RecvChan chan RecvMessage
|
||||||
|
|
||||||
|
transportNet transport.Net
|
||||||
|
filterFn FilterFn
|
||||||
|
endpoints map[netip.Addr]net.Conn
|
||||||
|
endpointsMu sync.Mutex
|
||||||
|
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||||
|
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
||||||
|
closedChan chan struct{}
|
||||||
|
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||||
|
closed bool
|
||||||
|
|
||||||
|
muUDPMux sync.Mutex
|
||||||
|
udpMux *UniversalUDPMuxDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||||
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
|
ib := &ICEBind{
|
||||||
|
StdNetBind: b,
|
||||||
|
RecvChan: make(chan RecvMessage, 1),
|
||||||
|
transportNet: transportNet,
|
||||||
|
filterFn: filterFn,
|
||||||
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
|
closedChan: make(chan struct{}),
|
||||||
|
closed: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := receiverCreator{
|
||||||
|
ib,
|
||||||
|
}
|
||||||
|
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
||||||
|
return ib
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||||
|
s.closed = false
|
||||||
|
s.closedChanMu.Lock()
|
||||||
|
s.closedChan = make(chan struct{})
|
||||||
|
s.closedChanMu.Unlock()
|
||||||
|
fns, port, err := s.StdNetBind.Open(uport)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
fns = append(fns, s.receiveRelayed)
|
||||||
|
return fns, port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) Close() error {
|
||||||
|
if s.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.closed = true
|
||||||
|
|
||||||
|
close(s.closedChan)
|
||||||
|
|
||||||
|
return s.StdNetBind.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||||
|
s.muUDPMux.Lock()
|
||||||
|
defer s.muUDPMux.Unlock()
|
||||||
|
if s.udpMux == nil {
|
||||||
|
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.udpMux, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
||||||
|
fakeUDPAddr, err := fakeAddress(peerAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// force IPv4
|
||||||
|
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.endpointsMu.Lock()
|
||||||
|
b.endpoints[fakeAddr] = conn
|
||||||
|
b.endpointsMu.Unlock()
|
||||||
|
|
||||||
|
return fakeUDPAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
||||||
|
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("failed to convert IP to netip.Addr")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.endpointsMu.Lock()
|
||||||
|
defer b.endpointsMu.Unlock()
|
||||||
|
delete(b.endpoints, fakeAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
|
b.endpointsMu.Lock()
|
||||||
|
conn, ok := b.endpoints[ep.DstIP()]
|
||||||
|
b.endpointsMu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
return b.StdNetBind.Send(bufs, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, buf := range bufs {
|
||||||
|
if _, err := conn.Write(buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
|
s.muUDPMux.Lock()
|
||||||
|
defer s.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
s.udpMux = NewUniversalUDPMuxDefault(
|
||||||
|
UniversalUDPMuxParams{
|
||||||
|
UDPConn: conn,
|
||||||
|
Net: s.transportNet,
|
||||||
|
FilterFn: s.filterFn,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
msgs := getMessages(msgsPool)
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
|
}
|
||||||
|
defer putMessages(msgs, msgsPool)
|
||||||
|
var numMsgs int
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
if rxOffload {
|
||||||
|
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||||
|
//nolint
|
||||||
|
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg := &(*msgs)[0]
|
||||||
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
|
||||||
|
// todo: handle err
|
||||||
|
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sizes[i] = msg.N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
|
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||||
|
for i := range buffers {
|
||||||
|
if !stun.IsMessage(buffers[i]) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := s.parseSTUNMessage(buffers[i][:n])
|
||||||
|
if err != nil {
|
||||||
|
buffers[i] = []byte{}
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||||
|
if muxErr != nil {
|
||||||
|
log.Warnf("failed to handle STUN packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
buffers[i] = []byte{}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||||
|
msg := &stun.Message{
|
||||||
|
Raw: raw,
|
||||||
|
}
|
||||||
|
if err := msg.Decode(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
|
||||||
|
// WireGuard. Critical part is do not block if the Closed() has been called.
|
||||||
|
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
c.closedChanMu.RLock()
|
||||||
|
defer c.closedChanMu.RUnlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.closedChan:
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
case msg, ok := <-c.RecvChan:
|
||||||
|
if !ok {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
copy(buffs[0], msg.Buffer)
|
||||||
|
sizes[0] = len(msg.Buffer)
|
||||||
|
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||||
|
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||||
|
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||||
|
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||||
|
if len(octets) != 4 {
|
||||||
|
return nil, fmt.Errorf("invalid IP format")
|
||||||
|
}
|
||||||
|
|
||||||
|
newAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
||||||
|
Port: peerAddress.Port,
|
||||||
|
}
|
||||||
|
return newAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||||
|
return msgsPool.Get().(*[]ipv6.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||||
|
for i := range *msgs {
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||||
|
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||||
|
}
|
||||||
|
msgsPool.Put(msgs)
|
||||||
|
}
|
||||||
@@ -5,7 +5,6 @@ package device
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -31,13 +30,13 @@ type WGTunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
|
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -29,14 +28,14 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
|
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
iceBind: iceBind,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package device
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -30,13 +29,13 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
|
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
key: key,
|
key: key,
|
||||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
iceBind: iceBind,
|
||||||
tunFd: tunFd,
|
tunFd: tunFd,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package device
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
@@ -31,7 +30,7 @@ type TunNetstackDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
|
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||||
return &TunNetstackDevice{
|
return &TunNetstackDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -39,7 +38,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
|
|||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
iceBind: iceBind,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -30,7 +29,7 @@ type USPDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
|
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||||
log.Infof("using userspace bind mode")
|
log.Infof("using userspace bind mode")
|
||||||
|
|
||||||
checkUser()
|
checkUser()
|
||||||
@@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int,
|
|||||||
port: port,
|
port: port,
|
||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: bind.NewICEBind(transportNet, filterFn)}
|
iceBind: iceBind,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *USPDevice) Create() (WGConfigurer, error) {
|
func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -32,14 +31,14 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
|
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: bind.NewICEBind(transportNet, filterFn),
|
iceBind: iceBind,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,12 +6,16 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -22,14 +26,35 @@ const (
|
|||||||
|
|
||||||
type WGAddress = device.WGAddress
|
type WGAddress = device.WGAddress
|
||||||
|
|
||||||
|
type wgProxyFactory interface {
|
||||||
|
GetProxy() wgproxy.Proxy
|
||||||
|
Free() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type WGIFaceOpts struct {
|
||||||
|
IFaceName string
|
||||||
|
Address string
|
||||||
|
WGPort int
|
||||||
|
WGPrivKey string
|
||||||
|
MTU int
|
||||||
|
MobileArgs *device.MobileIFaceArguments
|
||||||
|
TransportNet transport.Net
|
||||||
|
FilterFn bind.FilterFn
|
||||||
|
}
|
||||||
|
|
||||||
// WGIface represents an interface instance
|
// WGIface represents an interface instance
|
||||||
type WGIface struct {
|
type WGIface struct {
|
||||||
tun WGTunDevice
|
tun WGTunDevice
|
||||||
userspaceBind bool
|
userspaceBind bool
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
configurer device.WGConfigurer
|
configurer device.WGConfigurer
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
|
wgProxyFactory wgProxyFactory
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||||
|
return w.wgProxyFactory.GetProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||||
@@ -124,22 +149,26 @@ func (w *WGIface) Close() error {
|
|||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
err := w.tun.Close()
|
var result *multierror.Error
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
|
if err := w.wgProxyFactory.Free(); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.waitUntilRemoved()
|
if err := w.tun.Close(); err != nil {
|
||||||
if err != nil {
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.waitUntilRemoved(); err != nil {
|
||||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||||
err = w.Destroy()
|
if err := w.Destroy(); err != nil {
|
||||||
if err != nil {
|
result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err))
|
||||||
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
|
return errors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
log.Infof("interface %s successfully removed", w.Name())
|
log.Infof("interface %s successfully removed", w.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return errors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFilter sets packet filters for the userspace implementation
|
// SetFilter sets packet filters for the userspace implementation
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
|
||||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
|
||||||
wgAddress, err := device.ParseWGAddress(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
|
||||||
tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
|
|
||||||
userspaceBind: true,
|
|
||||||
}
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
|
||||||
// Will reuse an existing one.
|
|
||||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
cfgr, err := w.tun.Create(routes, dns, searchDomains)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
w.configurer = cfgr
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create this function make sense on mobile only
|
|
||||||
func (w *WGIface) Create() error {
|
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
|
||||||
}
|
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
// this function is different on Android
|
// this function is different on Android
|
||||||
@@ -17,3 +19,8 @@ func (w *WGIface) Create() error {
|
|||||||
w.configurer = cfgr
|
w.configurer = cfgr
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateOnAndroid this function make sense on mobile only
|
||||||
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
|
return fmt.Errorf("this function has not implemented on non mobile")
|
||||||
|
}
|
||||||
|
|||||||
24
client/iface/iface_create_android.go
Normal file
24
client/iface/iface_create_android.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
|
// Will reuse an existing one.
|
||||||
|
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
cfgr, err := w.tun.Create(routes, dns, searchDomains)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.configurer = cfgr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create this function make sense on mobile only
|
||||||
|
func (w *WGIface) Create() error {
|
||||||
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
|
}
|
||||||
@@ -7,39 +7,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
|
||||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
|
||||||
wgAddress, err := device.ParseWGAddress(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
|
||||||
userspaceBind: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
|
|
||||||
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateOnAndroid this function make sense on mobile only
|
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
// this function is different on Android
|
// this function is different on Android
|
||||||
@@ -65,3 +34,8 @@ func (w *WGIface) Create() error {
|
|||||||
|
|
||||||
return backoff.Retry(operation, backOff)
|
return backoff.Retry(operation, backOff)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateOnAndroid this function make sense on mobile only
|
||||||
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
|
}
|
||||||
10
client/iface/iface_guid_windows.go
Normal file
10
client/iface/iface_guid_windows.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
|
||||||
|
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
|
||||||
|
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
|
||||||
|
}
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
//go:build ios
|
|
||||||
|
|
||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
|
||||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
|
||||||
wgAddress, err := device.ParseWGAddress(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
wgIFace := &WGIface{
|
|
||||||
tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
|
|
||||||
userspaceBind: true,
|
|
||||||
}
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
|
||||||
// Will reuse an existing one.
|
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
|
||||||
}
|
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MockWGIface struct {
|
type MockWGIface struct {
|
||||||
@@ -30,6 +31,7 @@ type MockWGIface struct {
|
|||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||||
@@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
|||||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
return m.GetStatsFunc(peerKey)
|
return m.GetStatsFunc(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|||||||
24
client/iface/iface_new_android.go
Normal file
24
client/iface/iface_new_android.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||||
|
|
||||||
|
wgIFace := &WGIface{
|
||||||
|
userspaceBind: true,
|
||||||
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
|
}
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
34
client/iface/iface_new_darwin.go
Normal file
34
client/iface/iface_new_darwin.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||||
|
|
||||||
|
var tun WGTunDevice
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
|
} else {
|
||||||
|
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{
|
||||||
|
userspaceBind: true,
|
||||||
|
tun: tun,
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
|
}
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
26
client/iface/iface_new_ios.go
Normal file
26
client/iface/iface_new_ios.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||||
|
|
||||||
|
wgIFace := &WGIface{
|
||||||
|
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
||||||
|
userspaceBind: true,
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
|
}
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
45
client/iface/iface_new_unix.go
Normal file
45
client/iface/iface_new_unix.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||||
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.WireGuardModuleIsLoaded() {
|
||||||
|
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
if device.ModuleTunIsLoaded() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||||
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||||
|
}
|
||||||
32
client/iface/iface_new_windows.go
Normal file
32
client/iface/iface_new_windows.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||||
|
|
||||||
|
var tun WGTunDevice
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
|
} else {
|
||||||
|
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{
|
||||||
|
userspaceBind: true,
|
||||||
|
tun: tun,
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
|
}
|
||||||
|
return wgIFace, nil
|
||||||
|
|
||||||
|
}
|
||||||
@@ -45,7 +45,16 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil)
|
opts := WGIFaceOpts{
|
||||||
|
IFaceName: ifaceName,
|
||||||
|
Address: addr,
|
||||||
|
WGPort: wgPort,
|
||||||
|
WGPrivKey: key,
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -118,7 +127,16 @@ func Test_CreateInterface(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
|
opts := WGIFaceOpts{
|
||||||
|
IFaceName: ifaceName,
|
||||||
|
Address: wgIP,
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: key,
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -153,7 +171,16 @@ func Test_Close(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
|
opts := WGIFaceOpts{
|
||||||
|
IFaceName: ifaceName,
|
||||||
|
Address: wgIP,
|
||||||
|
WGPort: wgPort,
|
||||||
|
WGPrivKey: key,
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -189,7 +216,16 @@ func TestRecreation(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
|
opts := WGIFaceOpts{
|
||||||
|
IFaceName: ifaceName,
|
||||||
|
Address: wgIP,
|
||||||
|
WGPort: wgPort,
|
||||||
|
WGPrivKey: key,
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -252,7 +288,15 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
|
opts := WGIFaceOpts{
|
||||||
|
IFaceName: ifaceName,
|
||||||
|
Address: wgIP,
|
||||||
|
WGPort: wgPort,
|
||||||
|
WGPrivKey: key,
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
iface, err := NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -300,7 +344,16 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
|
opts := WGIFaceOpts{
|
||||||
|
IFaceName: ifaceName,
|
||||||
|
Address: wgIP,
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: key,
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -361,7 +414,16 @@ func Test_RemovePeer(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
|
opts := WGIFaceOpts{
|
||||||
|
IFaceName: ifaceName,
|
||||||
|
Address: wgIP,
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: key,
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -418,7 +480,15 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
|
||||||
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
|
optsPeer1 := WGIFaceOpts{
|
||||||
|
IFaceName: peer1ifaceName,
|
||||||
|
Address: peer1wgIP,
|
||||||
|
WGPort: peer1wgPort,
|
||||||
|
WGPrivKey: peer1Key.String(),
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
iface1, err := NewWGIFace(optsPeer1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -432,7 +502,12 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort))
|
localIP, err := getLocalIP()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer1wgPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -444,7 +519,17 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
|
|
||||||
|
optsPeer2 := WGIFaceOpts{
|
||||||
|
IFaceName: peer2ifaceName,
|
||||||
|
Address: peer2wgIP,
|
||||||
|
WGPort: peer2wgPort,
|
||||||
|
WGPrivKey: peer2Key.String(),
|
||||||
|
MTU: DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
iface2, err := NewWGIFace(optsPeer2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -458,7 +543,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort))
|
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer2wgPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -527,3 +612,28 @@ func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
|||||||
}
|
}
|
||||||
return wgtypes.Peer{}, fmt.Errorf("peer not found")
|
return wgtypes.Peer{}, fmt.Errorf("peer not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLocalIP() (string, error) {
|
||||||
|
// Get all interfaces
|
||||||
|
addrs, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipNet, ok := addr.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ipNet.IP.IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipNet.IP.To4() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return ipNet.IP.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no local IP found")
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
//go:build (linux && !android) || freebsd
|
|
||||||
|
|
||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
|
||||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
|
||||||
wgAddress, err := device.ParseWGAddress(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIFace := &WGIface{}
|
|
||||||
|
|
||||||
// move the kernel/usp/netstack preference evaluation to upper layer
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
|
|
||||||
wgIFace.userspaceBind = true
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if device.WireGuardModuleIsLoaded() {
|
|
||||||
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
|
|
||||||
wgIFace.userspaceBind = false
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !device.ModuleTunIsLoaded() {
|
|
||||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
|
||||||
}
|
|
||||||
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
|
|
||||||
wgIFace.userspaceBind = true
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateOnAndroid this function make sense on mobile only
|
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
|
||||||
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
|
|
||||||
}
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
|
||||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
|
||||||
wgAddress, err := device.ParseWGAddress(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
|
||||||
userspaceBind: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
|
|
||||||
return wgIFace, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateOnAndroid this function make sense on mobile only
|
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
|
||||||
return fmt.Errorf("this function has not implemented on non mobile")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
|
|
||||||
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
|
|
||||||
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
|
|
||||||
}
|
|
||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IWGIface interface {
|
type IWGIface interface {
|
||||||
@@ -22,6 +23,7 @@ type IWGIface interface {
|
|||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
|
GetProxy() wgproxy.Proxy
|
||||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP string) error
|
AddAllowedIP(peerKey string, allowedIP string) error
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IWGIface interface {
|
type IWGIface interface {
|
||||||
@@ -20,6 +21,7 @@ type IWGIface interface {
|
|||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
|
GetProxy() wgproxy.Proxy
|
||||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP string) error
|
AddAllowedIP(peerKey string, allowedIP string) error
|
||||||
|
|||||||
141
client/iface/wgproxy/bind/proxy.go
Normal file
141
client/iface/wgproxy/bind/proxy.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProxyBind struct {
|
||||||
|
Bind *bind.ICEBind
|
||||||
|
|
||||||
|
wgAddr *net.UDPAddr
|
||||||
|
wgEndpoint *bind.Endpoint
|
||||||
|
remoteConn net.Conn
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
closeMu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
|
||||||
|
pausedMu sync.Mutex
|
||||||
|
paused bool
|
||||||
|
isStarted bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTurnConn adds a new connection to the bind.
|
||||||
|
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||||
|
// WireGuard configuration.
|
||||||
|
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
|
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.wgAddr = addr
|
||||||
|
p.wgEndpoint = addrToEndpoint(addr)
|
||||||
|
p.remoteConn = remoteConn
|
||||||
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
|
return err
|
||||||
|
|
||||||
|
}
|
||||||
|
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||||
|
return p.wgAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) Work() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
// Start the proxy only once
|
||||||
|
if !p.isStarted {
|
||||||
|
p.isStarted = true
|
||||||
|
go p.proxyToLocal(p.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) Pause() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = true
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) CloseConn() error {
|
||||||
|
if p.cancel == nil {
|
||||||
|
return fmt.Errorf("proxy not started")
|
||||||
|
}
|
||||||
|
return p.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) close() error {
|
||||||
|
p.closeMu.Lock()
|
||||||
|
defer p.closeMu.Unlock()
|
||||||
|
|
||||||
|
if p.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.closed = true
|
||||||
|
|
||||||
|
p.cancel()
|
||||||
|
|
||||||
|
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||||
|
|
||||||
|
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||||
|
return rErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||||
|
defer func() {
|
||||||
|
if err := p.close(); err != nil {
|
||||||
|
log.Warnf("failed to close remote conn: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
n, err := p.remoteConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
if p.paused {
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := bind.RecvMessage{
|
||||||
|
Endpoint: p.wgEndpoint,
|
||||||
|
Buffer: buf[:n],
|
||||||
|
}
|
||||||
|
p.Bind.RecvChan <- msg
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||||
|
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||||
|
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||||
|
return &bind.Endpoint{AddrPort: addrPort}
|
||||||
|
}
|
||||||
@@ -5,9 +5,9 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
portRangeStart = 3128
|
portRangeStart = 3128
|
||||||
portRangeEnd = 3228
|
portRangeEnd = portRangeStart + 100
|
||||||
)
|
)
|
||||||
|
|
||||||
type portLookup struct {
|
type portLookup struct {
|
||||||
@@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) {
|
|||||||
func Test_portLookup_on_allocated(t *testing.T) {
|
func Test_portLookup_on_allocated(t *testing.T) {
|
||||||
pl := portLookup{}
|
pl := portLookup{}
|
||||||
|
|
||||||
|
portRangeStart = 4128
|
||||||
|
portRangeEnd = portRangeStart + 100
|
||||||
|
|
||||||
allocatedPort, err := allocatePort(portRangeStart)
|
allocatedPort, err := allocatePort(portRangeStart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -119,7 +119,7 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
p.ctxCancel()
|
p.ctxCancel()
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
if p.conn != nil { // p.conn will be nil if we have failed to listen
|
if p.conn != nil {
|
||||||
if err := p.conn.Close(); err != nil {
|
if err := p.conn.Close(); err != nil {
|
||||||
result = multierror.Append(result, err)
|
result = multierror.Append(result, err)
|
||||||
}
|
}
|
||||||
@@ -28,7 +28,7 @@ type ProxyWrapper struct {
|
|||||||
isStarted bool
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add turn conn: %w", err)
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
|
|
||||||
e.cancel()
|
e.cancel()
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil {
|
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
49
client/iface/wgproxy/factory_kernel.go
Normal file
49
client/iface/wgproxy/factory_kernel.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||||
|
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KernelFactory struct {
|
||||||
|
wgPort int
|
||||||
|
|
||||||
|
ebpfProxy *ebpf.WGEBPFProxy
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKernelFactory(wgPort int) *KernelFactory {
|
||||||
|
f := &KernelFactory{
|
||||||
|
wgPort: wgPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||||
|
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
|
||||||
|
f.ebpfProxy = ebpfProxy
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *KernelFactory) GetProxy() Proxy {
|
||||||
|
if w.ebpfProxy == nil {
|
||||||
|
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ebpf.ProxyWrapper{
|
||||||
|
WgeBPFProxy: w.ebpfProxy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *KernelFactory) Free() error {
|
||||||
|
if w.ebpfProxy == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.ebpfProxy.Free()
|
||||||
|
}
|
||||||
29
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
29
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KernelFactory todo: check eBPF support on FreeBSD
|
||||||
|
type KernelFactory struct {
|
||||||
|
wgPort int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKernelFactory(wgPort int) *KernelFactory {
|
||||||
|
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||||
|
f := &KernelFactory{
|
||||||
|
wgPort: wgPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *KernelFactory) GetProxy() Proxy {
|
||||||
|
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *KernelFactory) Free() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
30
client/iface/wgproxy/factory_usp.go
Normal file
30
client/iface/wgproxy/factory_usp.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||||
|
)
|
||||||
|
|
||||||
|
type USPFactory struct {
|
||||||
|
bind *bind.ICEBind
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
||||||
|
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
||||||
|
f := &USPFactory{
|
||||||
|
bind: iceBind,
|
||||||
|
}
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *USPFactory) GetProxy() Proxy {
|
||||||
|
return &proxyBind.ProxyBind{
|
||||||
|
Bind: w.bind,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *USPFactory) Free() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
15
client/iface/wgproxy/proxy.go
Normal file
15
client/iface/wgproxy/proxy.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||||
|
type Proxy interface {
|
||||||
|
AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error
|
||||||
|
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||||
|
Work() // Work start or resume the proxy
|
||||||
|
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||||
|
CloseConn() error
|
||||||
|
}
|
||||||
56
client/iface/wgproxy/proxy_linux_test.go
Normal file
56
client/iface/wgproxy/proxy_linux_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
||||||
|
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||||
|
t.Skip("Skipping test as it requires root privileges")
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
proxy Proxy
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ebpf proxy",
|
||||||
|
proxy: &ebpf.ProxyWrapper{
|
||||||
|
WgeBPFProxy: ebpfProxy,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
relayedConn := newMockConn()
|
||||||
|
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = relayedConn.Close()
|
||||||
|
if err := tt.proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "userspace proxy",
|
name: "userspace proxy",
|
||||||
proxy: usp.NewWGUserSpaceProxy(51830),
|
proxy: udpProxy.NewWGUDPProxy(51830),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
relayedConn := newMockConn()
|
relayedConn := newMockConn()
|
||||||
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error: %v", err)
|
t.Errorf("error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1,19 +1,21 @@
|
|||||||
package usp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
cerrors "github.com/netbirdio/netbird/client/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGUserSpaceProxy proxies
|
// WGUDPProxy proxies
|
||||||
type WGUserSpaceProxy struct {
|
type WGUDPProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
@@ -28,10 +30,10 @@ type WGUserSpaceProxy struct {
|
|||||||
isStarted bool
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
||||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
||||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||||
p := &WGUserSpaceProxy{
|
p := &WGUDPProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
@@ -42,7 +44,7 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
|||||||
// the connection is complete, an error is returned. Once successfully
|
// the connection is complete, an error is returned. Once successfully
|
||||||
// connected, any expiration of the context will not affect the
|
// connected, any expiration of the context will not affect the
|
||||||
// connection.
|
// connection.
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -57,7 +59,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
|
||||||
if p.localConn == nil {
|
if p.localConn == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -66,7 +68,7 @@ func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Work starts the proxy or resumes it if it was paused
|
// Work starts the proxy or resumes it if it was paused
|
||||||
func (p *WGUserSpaceProxy) Work() {
|
func (p *WGUDPProxy) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -83,7 +85,7 @@ func (p *WGUserSpaceProxy) Work() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Pause pauses the proxy from receiving data from the remote peer
|
// Pause pauses the proxy from receiving data from the remote peer
|
||||||
func (p *WGUserSpaceProxy) Pause() {
|
func (p *WGUDPProxy) Pause() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -94,14 +96,14 @@ func (p *WGUserSpaceProxy) Pause() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
// CloseConn close the localConn
|
||||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
func (p *WGUDPProxy) CloseConn() error {
|
||||||
if p.cancel == nil {
|
if p.cancel == nil {
|
||||||
return fmt.Errorf("proxy not started")
|
return fmt.Errorf("proxy not started")
|
||||||
}
|
}
|
||||||
return p.close()
|
return p.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGUserSpaceProxy) close() error {
|
func (p *WGUDPProxy) close() error {
|
||||||
p.closeMu.Lock()
|
p.closeMu.Lock()
|
||||||
defer p.closeMu.Unlock()
|
defer p.closeMu.Unlock()
|
||||||
|
|
||||||
@@ -114,18 +116,18 @@ func (p *WGUserSpaceProxy) close() error {
|
|||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
if err := p.remoteConn.Close(); err != nil {
|
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.localConn.Close(); err != nil {
|
if err := p.localConn.Close(); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||||
}
|
}
|
||||||
return errors.FormatErrorOrNil(result)
|
return cerrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to remote loop: %s", err)
|
log.Warnf("error in proxy to remote loop: %s", err)
|
||||||
@@ -157,21 +159,19 @@ func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
|||||||
|
|
||||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||||
// if the proxy is paused it will drain the remote conn and drop the packets
|
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to local loop: %s", err)
|
if !errors.Is(err, io.EOF) {
|
||||||
|
log.Warnf("error in proxy to local loop: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConnRead(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,3 +193,15 @@ func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) {
|
||||||
|
n, err = p.remoteConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package acl
|
|||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -10,14 +11,18 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||||
@@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||||
var newRouteRules = make(map[id.RuleID]struct{})
|
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
id, err := d.applyRouteACL(rule)
|
id, err := d.applyRouteACL(rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("apply route ACL: %w", err)
|
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||||
|
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
||||||
|
} else {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
newRouteRules[id] = struct{}{}
|
newRouteRules[id] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean up old firewall rules
|
||||||
for id := range d.routeRules {
|
for id := range d.routeRules {
|
||||||
if _, ok := newRouteRules[id]; !ok {
|
if _, exists := newRouteRules[id]; !exists {
|
||||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||||
log.Errorf("failed to delete route firewall rule: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
delete(d.routeRules, id)
|
// implicitly deleted from the map
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d.routeRules = newRouteRules
|
d.routeRules = newRouteRules
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||||
if len(rule.SourceRanges) == 0 {
|
if len(rule.SourceRanges) == 0 {
|
||||||
return "", fmt.Errorf("source ranges is empty")
|
return "", ErrSourceRangesEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
var sources []netip.Prefix
|
var sources []netip.Prefix
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(fw manager.Manager) {
|
defer func(fw manager.Manager) {
|
||||||
_ = fw.Reset()
|
_ = fw.Reset(nil)
|
||||||
}(fw)
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
@@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(fw manager.Manager) {
|
defer func(fw manager.Manager) {
|
||||||
_ = fw.Reset()
|
_ = fw.Reset(nil)
|
||||||
}(fw)
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||||
return cfg, err
|
return cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
|||||||
|
|
||||||
// WriteOutConfig write put the prepared config to the given path
|
// WriteOutConfig write put the prepared config to the given path
|
||||||
func WriteOutConfig(path string, config *Config) error {
|
func WriteOutConfig(path string, config *Config) error {
|
||||||
return util.WriteJson(path, config)
|
return util.WriteJson(context.Background(), path, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||||
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updated {
|
if updated {
|
||||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RunWithProbes runs the client's main logic with probes attached
|
// RunWithProbes runs the client's main logic with probes attached
|
||||||
func (c *ConnectClient) RunWithProbes(
|
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
||||||
probes *ProbeHolder,
|
|
||||||
runningChan chan error,
|
|
||||||
) error {
|
|
||||||
return c.run(MobileDependency{}, probes, runningChan)
|
return c.run(MobileDependency{}, probes, runningChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
return c.run(mobileDependency, nil, nil)
|
return c.run(mobileDependency, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(
|
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
||||||
mobileDependency MobileDependency,
|
|
||||||
probes *ProbeHolder,
|
|
||||||
runningChan chan error,
|
|
||||||
) error {
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||||
@@ -117,12 +110,6 @@ func (c *ConnectClient) run(
|
|||||||
|
|
||||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
|
||||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
|
||||||
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
|
|
||||||
log.Errorf("checking unclean shutdown error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -170,7 +157,8 @@ func (c *ConnectClient) run(
|
|||||||
|
|
||||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
c.statusRecorder.MarkManagementDisconnected(state.err)
|
_, err := state.Status()
|
||||||
|
c.statusRecorder.MarkManagementDisconnected(err)
|
||||||
c.statusRecorder.CleanLocalPeerState()
|
c.statusRecorder.CleanLocalPeerState()
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -220,7 +208,8 @@ func (c *ConnectClient) run(
|
|||||||
|
|
||||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||||
defer func() {
|
defer func() {
|
||||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
_, err := state.Status()
|
||||||
|
c.statusRecorder.MarkSignalDisconnected(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||||
@@ -358,7 +347,11 @@ func (c *ConnectClient) Stop() error {
|
|||||||
if c.engine == nil {
|
if c.engine == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.engine.Stop()
|
if err := c.engine.Stop(); err != nil {
|
||||||
|
return fmt.Errorf("stop engine: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) isContextCancelled() bool {
|
func (c *ConnectClient) isContextCancelled() bool {
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||||
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,6 +3,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -20,7 +22,7 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type repairConfFn func([]string, string, *resolvConf) error
|
type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error
|
||||||
|
|
||||||
type repair struct {
|
type repair struct {
|
||||||
operationFile string
|
operationFile string
|
||||||
@@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
|
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) {
|
||||||
if f.inotify != nil {
|
if f.inotify != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
|
|||||||
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
|
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to repair resolv.conf: %v", err)
|
log.Errorf("failed to repair resolv.conf: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -104,14 +105,14 @@ nameserver 8.8.8.8`,
|
|||||||
|
|
||||||
var changed bool
|
var changed bool
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
updateFn := func([]string, string, *resolvConf) error {
|
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
|
||||||
changed = true
|
changed = true
|
||||||
cancel()
|
cancel()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r := newRepair(operationFile, updateFn)
|
r := newRepair(operationFile, updateFn)
|
||||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
|
||||||
|
|
||||||
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -151,14 +152,14 @@ searchdomain netbird.cloud something`
|
|||||||
|
|
||||||
var changed bool
|
var changed bool
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
updateFn := func([]string, string, *resolvConf) error {
|
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
|
||||||
changed = true
|
changed = true
|
||||||
cancel()
|
cancel()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r := newRepair(tmpLink, updateFn)
|
r := newRepair(tmpLink, updateFn)
|
||||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
|
||||||
|
|
||||||
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -36,7 +38,7 @@ type fileConfigurator struct {
|
|||||||
nbNameserverIP string
|
nbNameserverIP string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFileConfigurator() (hostManager, error) {
|
func newFileConfigurator() (*fileConfigurator, error) {
|
||||||
fc := &fileConfigurator{}
|
fc := &fileConfigurator{}
|
||||||
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
||||||
return fc, nil
|
return fc, nil
|
||||||
@@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
backupFileExist := f.isBackupFileExist()
|
backupFileExist := f.isBackupFileExist()
|
||||||
if !config.RouteAll {
|
if !config.RouteAll {
|
||||||
if backupFileExist {
|
if backupFileExist {
|
||||||
@@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
|
|
||||||
f.repair.stopWatchFileChanges()
|
f.repair.stopWatchFileChanges()
|
||||||
|
|
||||||
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
|
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
|
||||||
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
||||||
nameServers := generateNsList(nbNameserverIP, cfg)
|
nameServers := generateNsList(nbNameserverIP, cfg)
|
||||||
|
|
||||||
@@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
|
|||||||
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||||
|
|
||||||
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
||||||
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
|
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error {
|
|||||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
|||||||
return restoreResolvConfFile()
|
return restoreResolvConfFile()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
|
log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,10 +190,6 @@ func restoreResolvConfFile() error {
|
|||||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,14 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config HostDNSConfig) error
|
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
supportCustomPort() bool
|
supportCustomPort() bool
|
||||||
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemDNSSettings struct {
|
type SystemDNSSettings struct {
|
||||||
@@ -35,15 +35,15 @@ type DomainConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockHostConfigurator struct {
|
type mockHostConfigurator struct {
|
||||||
applyDNSConfigFunc func(config HostDNSConfig) error
|
applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNSFunc func() error
|
restoreHostDNSFunc func() error
|
||||||
supportCustomPortFunc func() bool
|
supportCustomPortFunc func() bool
|
||||||
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
if m.applyDNSConfigFunc != nil {
|
if m.applyDNSConfigFunc != nil {
|
||||||
return m.applyDNSConfigFunc(config)
|
return m.applyDNSConfigFunc(config, stateManager)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("method applyDNSSettings is not implemented")
|
return fmt.Errorf("method applyDNSSettings is not implemented")
|
||||||
}
|
}
|
||||||
@@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
|
||||||
if m.restoreUncleanShutdownDNSFunc != nil {
|
|
||||||
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNoopHostMocker() hostManager {
|
func newNoopHostMocker() hostManager {
|
||||||
return &mockHostConfigurator{
|
return &mockHostConfigurator{
|
||||||
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
|
||||||
restoreHostDNSFunc: func() error { return nil },
|
restoreHostDNSFunc: func() error { return nil },
|
||||||
supportCustomPortFunc: func() bool { return true },
|
supportCustomPortFunc: func() bool { return true },
|
||||||
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import "net/netip"
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
type androidHostManager struct {
|
type androidHostManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (hostManager, error) {
|
func newHostManager() (*androidHostManager, error) {
|
||||||
return &androidHostManager{}, nil
|
return &androidHostManager{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
|
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error {
|
|||||||
func (a androidHostManager) supportCustomPort() bool {
|
func (a androidHostManager) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -37,7 +38,7 @@ type systemConfigurator struct {
|
|||||||
systemDNSSettings SystemDNSSettings
|
systemDNSSettings SystemDNSSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (hostManager, error) {
|
func newHostManager() (*systemConfigurator, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// create a file for unclean shutdown detection
|
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
||||||
if err := createUncleanShutdownIndicator(); err != nil {
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
|||||||
return primaryService, router, nil
|
return primaryService, router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
|
||||||
if err := s.restoreHostDNS(); err != nil {
|
if err := s.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type iosHostManager struct {
|
type iosHostManager struct {
|
||||||
@@ -13,13 +14,13 @@ type iosHostManager struct {
|
|||||||
config HostDNSConfig
|
config HostDNSConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
|
func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
|
||||||
return &iosHostManager{
|
return &iosHostManager{
|
||||||
dnsManager: dnsManager,
|
dnsManager: dnsManager,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
jsonData, err := json.Marshal(config)
|
jsonData, err := json.Marshal(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal: %w", err)
|
return fmt.Errorf("marshal: %w", err)
|
||||||
@@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error {
|
|||||||
func (a iosHostManager) supportCustomPort() bool {
|
func (a iosHostManager) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -21,27 +21,8 @@ const (
|
|||||||
resolvConfManager
|
resolvConfManager
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
|
|
||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
func newOsManagerType(osManager string) (osManagerType, error) {
|
|
||||||
switch osManager {
|
|
||||||
case "netbird":
|
|
||||||
return fileManager, nil
|
|
||||||
case "file":
|
|
||||||
return netbirdManager, nil
|
|
||||||
case "networkManager":
|
|
||||||
return networkManager, nil
|
|
||||||
case "systemd":
|
|
||||||
return systemdManager, nil
|
|
||||||
case "resolvconf":
|
|
||||||
return resolvConfManager, nil
|
|
||||||
default:
|
|
||||||
return 0, ErrUnknownOsManagerType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t osManagerType) String() string {
|
func (t osManagerType) String() string {
|
||||||
switch t {
|
switch t {
|
||||||
case netbirdManager:
|
case netbirdManager:
|
||||||
@@ -59,6 +40,11 @@ func (t osManagerType) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type restoreHostManager interface {
|
||||||
|
hostManager
|
||||||
|
restoreUncleanShutdownDNS(*netip.Addr) error
|
||||||
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
|
|||||||
return newHostManagerFromType(wgInterface, osManager)
|
return newHostManagerFromType(wgInterface, osManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
|
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
|
||||||
switch osManager {
|
switch osManager {
|
||||||
case networkManager:
|
case networkManager:
|
||||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -31,7 +32,7 @@ type registryConfigurator struct {
|
|||||||
routingAll bool
|
routingAll bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|||||||
return newHostManagerWithGuid(guid)
|
return newHostManagerWithGuid(guid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManagerWithGuid(guid string) (hostManager, error) {
|
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
|
||||||
return ®istryConfigurator{
|
return ®istryConfigurator{
|
||||||
guid: guid,
|
guid: guid,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
err = r.addDNSSetupForAll(config.ServerIP)
|
err = r.addDNSSetupForAll(config.ServerIP)
|
||||||
@@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a file for unclean shutdown detection
|
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
|
||||||
if err := createUncleanShutdownIndicator(r.guid); err != nil {
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
|||||||
return regKey, nil
|
return regKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
||||||
if err := r.restoreHostDNS(); err != nil {
|
if err := r.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
return fmt.Errorf("restoring dns via registry: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbversion "github.com/netbirdio/netbird/version"
|
nbversion "github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{
|
|||||||
type networkManagerDbusConfigurator struct {
|
type networkManagerDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
routingAll bool
|
||||||
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||||
@@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get nm dbus: %w", err)
|
return nil, fmt.Errorf("get nm dbus: %w", err)
|
||||||
@@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error)
|
|||||||
|
|
||||||
return &networkManagerDbusConfigurator{
|
return &networkManagerDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
|
ifaceName: wgInterface,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
||||||
@@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
|
|||||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||||
|
|
||||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
state := &ShutdownState{
|
||||||
// The file content itself is not important for network-manager restoration
|
ManagerType: networkManager,
|
||||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
|
WgIface: n.ifaceName,
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||||
@@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("delete connection settings: %w", err)
|
return fmt.Errorf("delete connection settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
@@ -22,7 +24,7 @@ type resolvconf struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// supported "openresolv" only
|
// supported "openresolv" only
|
||||||
func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
|
||||||
resolvConfEntries, err := parseDefaultResolvConf()
|
resolvConfEntries, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||||
@@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
if !config.RouteAll {
|
if !config.RouteAll {
|
||||||
err = r.restoreHostDNS()
|
err = r.restoreHostDNS()
|
||||||
@@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
append([]string{config.ServerIP}, r.originalNameServers...),
|
||||||
options)
|
options)
|
||||||
|
|
||||||
// create a backup for unclean shutdown detection before the resolv.conf is changed
|
state := &ShutdownState{
|
||||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
|
ManagerType: resolvConfManager,
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
WgIface: r.ifaceName,
|
||||||
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.applyConfig(buf)
|
err = r.applyConfig(buf)
|
||||||
@@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error {
|
|||||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
|||||||
cmd.Stdin = &content
|
cmd.Stdin = &content
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -63,6 +64,7 @@ type DefaultServer struct {
|
|||||||
iosDnsManager IosDnsManager
|
iosDnsManager IosDnsManager
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
@@ -77,12 +79,7 @@ type muxUpdate struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
|
||||||
ctx context.Context,
|
|
||||||
wgInterface WGIface,
|
|
||||||
customAddress string,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
) (*DefaultServer, error) {
|
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
@@ -99,7 +96,7 @@ func NewDefaultServer(
|
|||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
|
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
@@ -112,7 +109,7 @@ func NewDefaultServerPermanentUpstream(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||||
ds.hostsDNSHolder.set(hostsDnsList)
|
ds.hostsDNSHolder.set(hostsDnsList)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
@@ -130,12 +127,12 @@ func NewDefaultServerIos(
|
|||||||
iosDnsManager IosDnsManager,
|
iosDnsManager IosDnsManager,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||||
ds.iosDnsManager = iosDnsManager
|
ds.iosDnsManager = iosDnsManager
|
||||||
return ds
|
return ds
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
|
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -147,6 +144,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
},
|
},
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
stateManager: stateManager,
|
||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,6 +167,7 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.stateManager.RegisterState(&ShutdownState{})
|
||||||
s.hostManager, err = s.initialize()
|
s.hostManager, err = s.initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initialize: %w", err)
|
return fmt.Errorf("initialize: %w", err)
|
||||||
@@ -191,9 +190,10 @@ func (s *DefaultServer) Stop() {
|
|||||||
s.ctxCancel()
|
s.ctxCancel()
|
||||||
|
|
||||||
if s.hostManager != nil {
|
if s.hostManager != nil {
|
||||||
err := s.hostManager.restoreHostDNS()
|
if err := s.hostManager.restoreHostDNS(); err != nil {
|
||||||
if err != nil {
|
log.Error("failed to restore host DNS settings: ", err)
|
||||||
log.Error(err)
|
} else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
log.Errorf("failed to delete shutdown dns state: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,10 +318,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
hostUpdate.RouteAll = false
|
hostUpdate.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// persist dns state right away
|
||||||
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
|
log.Errorf("Failed to persist dns state: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if s.searchDomainNotifier != nil {
|
if s.searchDomainNotifier != nil {
|
||||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||||
}
|
}
|
||||||
@@ -521,10 +528,16 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
|
l.Errorf("Failed to persist dns state: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
}
|
}
|
||||||
@@ -551,7 +564,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
@@ -267,7 +268,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||||
|
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -281,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -345,7 +356,15 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: "utun2301",
|
||||||
|
Address: "100.66.100.1/32",
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("build interface wireguard: %v", err)
|
t.Errorf("build interface wireguard: %v", err)
|
||||||
return
|
return
|
||||||
@@ -382,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -477,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
@@ -536,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
service: NewServiceViaMemory(&mocWGIface{}),
|
service: NewServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
@@ -552,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var domainsUpdate string
|
var domainsUpdate string
|
||||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
|
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||||
domains := []string{}
|
domains := []string{}
|
||||||
for _, item := range config.Domains {
|
for _, item := range config.Domains {
|
||||||
if item.Disabled {
|
if item.Disabled {
|
||||||
@@ -762,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
Port: 53,
|
Port: 53,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Domains: []string{"customdomain.com"},
|
Domains: []string{"google.com"},
|
||||||
Primary: false,
|
Primary: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -784,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
if ips[0] != zoneRecords[0].RData {
|
if ips[0] != zoneRecords[0].RData {
|
||||||
t.Fatalf("invalid zone record: %v", err)
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
}
|
}
|
||||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
_, err = resolver.LookupHost(context.Background(), "google.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to resolve: %s", err)
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
@@ -803,7 +823,17 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: "utun2301",
|
||||||
|
Address: "100.66.100.2/24",
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("build interface wireguard: %v", err)
|
t.Fatalf("build interface wireguard: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||||
return newHostManager(s.wgInterface)
|
return newHostManager(s.wgInterface)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
var errNotImplemented = errors.New("not implemented")
|
var errNotImplemented = errors.New("not implemented")
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
func newSystemdDbusConfigurator(string) (restoreHostManager, error) {
|
||||||
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,6 +39,7 @@ const (
|
|||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
routingAll bool
|
||||||
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||||
@@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
MatchOnly bool
|
MatchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface)
|
iface, err := net.InterfaceByName(wgInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get interface: %w", err)
|
return nil, fmt.Errorf("get interface: %w", err)
|
||||||
@@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
|||||||
|
|
||||||
return &systemdDbusConfigurator{
|
return &systemdDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
|
ifaceName: wgInterface,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||||
@@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
state := &ShutdownState{
|
||||||
// The file content itself is not important for systemd restoration
|
ManagerType: systemdManager,
|
||||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
|
WgIface: s.ifaceName,
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||||
@@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.flushCaches()
|
return s.flushCaches()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,57 +3,25 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
|
type ShutdownState struct {
|
||||||
|
}
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
func (s *ShutdownState) Name() string {
|
||||||
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
|
return "dns_state"
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
}
|
||||||
// no file -> clean shutdown
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
|
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
manager, err := newHostManager()
|
manager, err := newHostManager()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create host manager: %w", err)
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUncleanShutdownIndicator() error {
|
|
||||||
dir := filepath.Dir(fileUncleanShutdownFileLocation)
|
|
||||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
|
||||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
|
|
||||||
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUncleanShutdownIndicator() error {
|
|
||||||
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
14
client/internal/dns/unclean_shutdown_mobile.go
Normal file
14
client/internal/dns/unclean_shutdown_mobile.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
type ShutdownState struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -3,66 +3,44 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CheckUncleanShutdown(wgIface string) error {
|
type ShutdownState struct {
|
||||||
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
ManagerType osManagerType
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
DNSAddress netip.Addr
|
||||||
// no file -> clean shutdown
|
WgIface string
|
||||||
return nil
|
}
|
||||||
} else {
|
|
||||||
return fmt.Errorf("state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation)
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation)
|
func (s *ShutdownState) Cleanup() error {
|
||||||
if err != nil {
|
manager, err := newHostManagerFromType(s.WgIface, s.ManagerType)
|
||||||
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
managerFields := strings.Split(string(managerData), ",")
|
|
||||||
if len(managerFields) < 2 {
|
|
||||||
return errors.New("split manager data: insufficient number of fields")
|
|
||||||
}
|
|
||||||
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
|
|
||||||
|
|
||||||
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
|
|
||||||
|
|
||||||
// determine os manager type, so we can invoke the respective restore action
|
|
||||||
osManagerType, err := newOsManagerType(osManagerTypeStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("detect previous host manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := newHostManagerFromType(wgIface, osManagerType)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create previous host manager: %w", err)
|
return fmt.Errorf("create previous host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error {
|
// TODO: move file contents to state manager
|
||||||
|
func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error {
|
||||||
|
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
|
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
|
||||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||||
@@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType
|
|||||||
return fmt.Errorf("create %s: %w", sourcePath, err)
|
return fmt.Errorf("create %s: %w", sourcePath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress)
|
state := &ShutdownState{
|
||||||
|
ManagerType: fileManager,
|
||||||
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec
|
DNSAddress: dnsAddress,
|
||||||
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUncleanShutdownIndicator() error {
|
|
||||||
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
|
|
||||||
}
|
|
||||||
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
|
||||||
}
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
return fmt.Errorf("update state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,75 +1,26 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type ShutdownState struct {
|
||||||
netbirdProgramDataLocation = "Netbird"
|
Guid string
|
||||||
fileUncleanShutdownFile = "unclean_shutdown_dns.txt"
|
}
|
||||||
)
|
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
func (s *ShutdownState) Name() string {
|
||||||
file := getUncleanShutdownFile()
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := os.Stat(file); err != nil {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
manager, err := newHostManagerWithGuid(s.Guid)
|
||||||
// no file -> clean shutdown
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
|
|
||||||
|
|
||||||
guid, err := os.ReadFile(file)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read %s: %w", file, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := newHostManagerWithGuid(string(guid))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create host manager: %w", err)
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUncleanShutdownIndicator(guid string) error {
|
|
||||||
file := getUncleanShutdownFile()
|
|
||||||
|
|
||||||
dir := filepath.Dir(file)
|
|
||||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
|
||||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
|
|
||||||
return fmt.Errorf("create %s: %w", file, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUncleanShutdownIndicator() error {
|
|
||||||
file := getUncleanShutdownFile()
|
|
||||||
|
|
||||||
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", file, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUncleanShutdownFile() string {
|
|
||||||
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -23,19 +24,21 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -141,8 +144,7 @@ type Engine struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wgInterface iface.IWGIface
|
wgInterface iface.IWGIface
|
||||||
wgProxyFactory *wgproxy.Factory
|
|
||||||
|
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *bind.UniversalUDPMuxDefault
|
||||||
|
|
||||||
@@ -168,6 +170,8 @@ type Engine struct {
|
|||||||
checks []*mgmProto.Checks
|
checks []*mgmProto.Checks
|
||||||
|
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
|
stateManager *statemanager.Manager
|
||||||
|
srWatcher *guard.SRWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -215,7 +219,7 @@ func NewEngineWithProbes(
|
|||||||
probes *ProbeHolder,
|
probes *ProbeHolder,
|
||||||
checks []*mgmProto.Checks,
|
checks []*mgmProto.Checks,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
return &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
signal: signalClient,
|
signal: signalClient,
|
||||||
@@ -234,6 +238,11 @@ func NewEngineWithProbes(
|
|||||||
probes: probes,
|
probes: probes,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
}
|
}
|
||||||
|
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||||
|
engine.stateManager = statemanager.New(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return engine
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) Stop() error {
|
func (e *Engine) Stop() error {
|
||||||
@@ -255,7 +264,11 @@ func (e *Engine) Stop() error {
|
|||||||
e.stopDNSServer()
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
e.routeManager.Stop()
|
e.routeManager.Stop(e.stateManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.srWatcher != nil {
|
||||||
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
@@ -277,6 +290,17 @@ func (e *Engine) Stop() error {
|
|||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
log.Infof("stopped Netbird Engine")
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := e.stateManager.Stop(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||||
|
}
|
||||||
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -299,9 +323,6 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.wgInterface = wgIface
|
e.wgInterface = wgIface
|
||||||
|
|
||||||
userspace := e.wgInterface.IsUserspaceBind()
|
|
||||||
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
|
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled {
|
||||||
log.Infof("rosenpass is enabled")
|
log.Infof("rosenpass is enabled")
|
||||||
if e.config.RosenpassPermissive {
|
if e.config.RosenpassPermissive {
|
||||||
@@ -319,6 +340,8 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.stateManager.Start()
|
||||||
|
|
||||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
@@ -327,7 +350,7 @@ func (e *Engine) Start() error {
|
|||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -344,7 +367,7 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -374,6 +397,18 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("initialize dns server: %w", err)
|
return fmt.Errorf("initialize dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
iceCfg := icemaker.Config{
|
||||||
|
StunTurn: &e.stunTurn,
|
||||||
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
|
UDPMux: e.udpMux.UDPMuxDefault,
|
||||||
|
UDPMuxSrflx: e.udpMux,
|
||||||
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
}
|
||||||
|
|
||||||
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
|
e.srWatcher.Start()
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
e.receiveProbeEvents()
|
e.receiveProbeEvents()
|
||||||
@@ -606,6 +641,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wireguard interface is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
if e.wgInterface.Address().String() != conf.Address {
|
if e.wgInterface.Address().String() != conf.Address {
|
||||||
oldAddr := e.wgInterface.Address().String()
|
oldAddr := e.wgInterface.Address().String()
|
||||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||||
@@ -956,7 +995,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
LocalWgPort: e.config.WgPort,
|
LocalWgPort: e.config.WgPort,
|
||||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||||
RosenpassAddr: e.getRosenpassAddr(),
|
RosenpassAddr: e.getRosenpassAddr(),
|
||||||
ICEConfig: peer.ICEConfig{
|
ICEConfig: icemaker.Config{
|
||||||
StunTurn: &e.stunTurn,
|
StunTurn: &e.stunTurn,
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
@@ -966,7 +1005,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
|
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1117,12 +1156,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
if e.wgProxyFactory != nil {
|
|
||||||
if err := e.wgProxyFactory.Free(); err != nil {
|
|
||||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
if err := e.wgInterface.Close(); err != nil {
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
@@ -1139,7 +1172,7 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
err := e.firewall.Reset()
|
err := e.firewall.Reset(e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to reset firewall: %s", err)
|
log.Warnf("failed to reset firewall: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1167,21 +1200,29 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
|||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var mArgs *device.MobileIFaceArguments
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: e.config.WgIfaceName,
|
||||||
|
Address: e.config.WgAddr,
|
||||||
|
WGPort: e.config.WgPort,
|
||||||
|
WGPrivKey: e.config.WgPrivateKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: transportNet,
|
||||||
|
FilterFn: e.addrViaRoutes,
|
||||||
|
}
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "android":
|
case "android":
|
||||||
mArgs = &device.MobileIFaceArguments{
|
opts.MobileArgs = &device.MobileIFaceArguments{
|
||||||
TunAdapter: e.mobileDep.TunAdapter,
|
TunAdapter: e.mobileDep.TunAdapter,
|
||||||
TunFd: int(e.mobileDep.FileDescriptor),
|
TunFd: int(e.mobileDep.FileDescriptor),
|
||||||
}
|
}
|
||||||
case "ios":
|
case "ios":
|
||||||
mArgs = &device.MobileIFaceArguments{
|
opts.MobileArgs = &device.MobileIFaceArguments{
|
||||||
TunFd: int(e.mobileDep.FileDescriptor),
|
TunFd: int(e.mobileDep.FileDescriptor),
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
|
return iface.NewWGIFace(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) wgInterfaceCreate() (err error) {
|
func (e *Engine) wgInterfaceCreate() (err error) {
|
||||||
@@ -1222,10 +1263,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
||||||
return nil, dnsServer, nil
|
return nil, dnsServer, nil
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, dnsServer, nil
|
return nil, dnsServer, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1443,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
|
|||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
|
for _, check := range checks {
|
||||||
|
sort.Slice(check.Files, func(i, j int) bool {
|
||||||
|
return check.Files[i] < check.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, oCheck := range oChecks {
|
||||||
|
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||||
|
return oCheck.Files[i] < oCheck.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||||
return slices.Equal(checks.Files, oChecks.Files)
|
return slices.Equal(checks.Files, oChecks.Files)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
name string
|
name string
|
||||||
@@ -602,7 +605,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: wgIfaceName,
|
||||||
|
Address: wgAddr,
|
||||||
|
WGPort: engine.config.WgPort,
|
||||||
|
WGPrivKey: key.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
engine.wgInterface, err = iface.NewWGIFace(opts)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
input := struct {
|
input := struct {
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
@@ -774,7 +786,15 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: wgIfaceName,
|
||||||
|
Address: wgAddr,
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: key.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
engine.wgInterface, err = iface.NewWGIFace(opts)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
@@ -986,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_CheckFilesEqual(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputChecks1 []*mgmtProto.Checks
|
||||||
|
inputChecks2 []*mgmtProto.Checks
|
||||||
|
expectedBool bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equal Files In Equal Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Equal Files In Reverse Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile2",
|
||||||
|
"testfile1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unequal Files Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Compared With Empty Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||||
|
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user