mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
Compare commits
3 Commits
update-get
...
separate_p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b43d7e8ef | ||
|
|
dcc83c8741 | ||
|
|
d56669ec2e |
6
.github/workflows/golang-test-linux.yml
vendored
6
.github/workflows/golang-test-linux.yml
vendored
@@ -78,9 +78,6 @@ jobs:
|
|||||||
- name: Generate RouteManager Test bin
|
- name: Generate RouteManager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
|
||||||
|
|
||||||
- name: Generate Engine Test bin
|
- name: Generate Engine Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal
|
run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal
|
||||||
|
|
||||||
@@ -99,9 +96,6 @@ jobs:
|
|||||||
- name: Run RouteManager tests in docker
|
- name: Run RouteManager tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run nftables Manager tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker
|
- name: Run Engine tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
|
|||||||
16
.github/workflows/release.yml
vendored
16
.github/workflows/release.yml
vendored
@@ -7,19 +7,9 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
|
||||||
- 'go.mod'
|
|
||||||
- 'go.sum'
|
|
||||||
- '.goreleaser.yml'
|
|
||||||
- '.goreleaser_ui.yaml'
|
|
||||||
- '.goreleaser_ui_darwin.yaml'
|
|
||||||
- '.github/workflows/release.yml'
|
|
||||||
- 'release_files/**'
|
|
||||||
- '**/Dockerfile'
|
|
||||||
- '**/Dockerfile.*'
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.8"
|
SIGN_PIPE_VER: "v0.0.6"
|
||||||
GORELEASER_VER: "v1.14.1"
|
GORELEASER_VER: "v1.14.1"
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -126,7 +116,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
||||||
- name: Install rsrc
|
- name: Install rsrc
|
||||||
run: go install github.com/akavel/rsrc@v0.10.2
|
run: go install github.com/akavel/rsrc@v0.10.2
|
||||||
- name: Generate windows rsrc
|
- name: Generate windows rsrc
|
||||||
@@ -205,7 +195,7 @@ jobs:
|
|||||||
|
|
||||||
trigger_darwin_signer:
|
trigger_darwin_signer:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [release,release_ui_darwin]
|
needs: release_ui_darwin
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger Darwin App binaries sign pipeline
|
- name: Trigger Darwin App binaries sign pipeline
|
||||||
|
|||||||
@@ -1,21 +1,17 @@
|
|||||||
name: Test Infrastructure files
|
name: Test Docker Compose Linux
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
|
||||||
- 'infrastructure_files/**'
|
|
||||||
- '.github/workflows/test-infrastructure-files.yml'
|
|
||||||
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-docker-compose:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install jq
|
- name: Install jq
|
||||||
@@ -38,7 +34,7 @@ jobs:
|
|||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: cp setup.env
|
- name: cp setup.env
|
||||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||||
@@ -49,21 +45,15 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
CI_NETBIRD_DOMAIN: localhost
|
CI_NETBIRD_DOMAIN: localhost
|
||||||
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
||||||
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
|
|
||||||
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
||||||
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
||||||
CI_NETBIRD_USE_AUTH0: true
|
CI_NETBIRD_USE_AUTH0: true
|
||||||
CI_NETBIRD_MGMT_IDP: "none"
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
|
||||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
|
||||||
|
|
||||||
- name: check values
|
- name: check values
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
env:
|
env:
|
||||||
CI_NETBIRD_DOMAIN: localhost
|
CI_NETBIRD_DOMAIN: localhost
|
||||||
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
||||||
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
|
|
||||||
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
||||||
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
||||||
CI_NETBIRD_USE_AUTH0: true
|
CI_NETBIRD_USE_AUTH0: true
|
||||||
@@ -72,19 +62,14 @@ jobs:
|
|||||||
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
||||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||||
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
|
|
||||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||||
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||||
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
|
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
|
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
|
||||||
CI_NETBIRD_MGMT_IDP: "none"
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
|
||||||
|
|
||||||
run: |
|
run: |
|
||||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||||
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
|
||||||
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||||
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
|
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
|
||||||
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||||
@@ -95,21 +80,9 @@ jobs:
|
|||||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||||
grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||||
grep UseIDToken management.json | grep false
|
grep UseIDToken management.json | grep false
|
||||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
|
||||||
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
|
||||||
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
|
||||||
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
|
||||||
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
|
||||||
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
|
|
||||||
grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
|
|
||||||
grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
|
||||||
grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
|
||||||
grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
|
||||||
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
|
||||||
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
|
||||||
|
|
||||||
- name: run docker compose up
|
- name: run docker compose up
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
@@ -124,28 +97,3 @@ jobs:
|
|||||||
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
||||||
test $count -eq 4
|
test $count -eq 4
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
|
|
||||||
test-getting-started-script:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Install jq
|
|
||||||
run: sudo apt-get install -y jq
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: run script
|
|
||||||
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
|
|
||||||
|
|
||||||
- name: test Caddy file gen
|
|
||||||
run: test -f Caddyfile
|
|
||||||
- name: test docker-compose file gen
|
|
||||||
run: test -f docker-compose.yml
|
|
||||||
- name: test management.json file gen
|
|
||||||
run: test -f management.json
|
|
||||||
- name: test turnserver.conf file gen
|
|
||||||
run: test -f turnserver.conf
|
|
||||||
- name: test zitadel.env file gen
|
|
||||||
run: test -f zitadel.env
|
|
||||||
- name: test dashboard.env file gen
|
|
||||||
run: test -f dashboard.env
|
|
||||||
22
.github/workflows/update-docs.yml
vendored
22
.github/workflows/update-docs.yml
vendored
@@ -1,22 +0,0 @@
|
|||||||
name: update docs
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- 'v*'
|
|
||||||
paths:
|
|
||||||
- 'management/server/http/api/openapi.yml'
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
trigger_docs_api_update:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
|
||||||
steps:
|
|
||||||
- name: Trigger API pages generation
|
|
||||||
uses: benc-uk/workflow-dispatch@v1
|
|
||||||
with:
|
|
||||||
workflow: generate api pages
|
|
||||||
repo: netbirdio/docs
|
|
||||||
ref: "refs/heads/main"
|
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -7,15 +7,8 @@ bin/
|
|||||||
conf.json
|
conf.json
|
||||||
http-cmds.sh
|
http-cmds.sh
|
||||||
infrastructure_files/management.json
|
infrastructure_files/management.json
|
||||||
infrastructure_files/management-*.json
|
|
||||||
infrastructure_files/docker-compose.yml
|
infrastructure_files/docker-compose.yml
|
||||||
infrastructure_files/openid-configuration.json
|
|
||||||
infrastructure_files/turnserver.conf
|
|
||||||
management/management
|
|
||||||
client/client
|
|
||||||
client/client.exe
|
|
||||||
*.syso
|
*.syso
|
||||||
client/.distfiles/
|
client/.distfiles/
|
||||||
infrastructure_files/setup.env
|
infrastructure_files/setup.env
|
||||||
infrastructure_files/setup-*.env
|
|
||||||
.vscode
|
.vscode
|
||||||
|
|||||||
@@ -12,7 +12,11 @@ builds:
|
|||||||
- arm
|
- arm
|
||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
|
- mips
|
||||||
- 386
|
- 386
|
||||||
|
gomips:
|
||||||
|
- hardfloat
|
||||||
|
- softfloat
|
||||||
ignore:
|
ignore:
|
||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -26,26 +30,6 @@ builds:
|
|||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|
||||||
- id: netbird-static
|
|
||||||
dir: client
|
|
||||||
binary: netbird
|
|
||||||
env: [CGO_ENABLED=0]
|
|
||||||
goos:
|
|
||||||
- linux
|
|
||||||
goarch:
|
|
||||||
- mips
|
|
||||||
- mipsle
|
|
||||||
- mips64
|
|
||||||
- mips64le
|
|
||||||
gomips:
|
|
||||||
- hardfloat
|
|
||||||
- softfloat
|
|
||||||
ldflags:
|
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
|
||||||
tags:
|
|
||||||
- load_wgnt_from_rsrc
|
|
||||||
|
|
||||||
- id: netbird-mgmt
|
- id: netbird-mgmt
|
||||||
dir: management
|
dir: management
|
||||||
env:
|
env:
|
||||||
@@ -83,7 +67,6 @@ builds:
|
|||||||
archives:
|
archives:
|
||||||
- builds:
|
- builds:
|
||||||
- netbird
|
- netbird
|
||||||
- netbird-static
|
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
|
|
||||||
@@ -377,8 +360,3 @@ uploads:
|
|||||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
method: PUT
|
method: PUT
|
||||||
|
|
||||||
release:
|
|
||||||
extra_files:
|
|
||||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
|
||||||
- glob: ./release_files/install.sh
|
|
||||||
@@ -11,8 +11,6 @@ builds:
|
|||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
tags:
|
|
||||||
- legacy_appindicator
|
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows
|
||||||
@@ -57,6 +55,9 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
|
- libayatana-appindicator3-1
|
||||||
|
- libgtk-3-dev
|
||||||
|
- libappindicator3-dev
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
@@ -74,6 +75,9 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
|
- libayatana-appindicator3-1
|
||||||
|
- libgtk-3-dev
|
||||||
|
- libappindicator3-dev
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
uploads:
|
uploads:
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -57,10 +57,9 @@ NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactiv
|
|||||||
- \[x] Network Routes.
|
- \[x] Network Routes.
|
||||||
- \[x] Private DNS.
|
- \[x] Private DNS.
|
||||||
- \[x] Network Activity Monitoring.
|
- \[x] Network Activity Monitoring.
|
||||||
- \[x] Mobile clients (Android).
|
|
||||||
-
|
|
||||||
**Coming soon:**
|
**Coming soon:**
|
||||||
- \[ ] Mobile clients (iOS).
|
- \[ ] Mobile clients.
|
||||||
|
|
||||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||||
|
|
||||||
@@ -71,9 +70,9 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
|||||||
|
|
||||||
### Start using NetBird
|
### Start using NetBird
|
||||||
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
||||||
- See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
|
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
||||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
|
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
||||||
- Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
|
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
||||||
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
||||||
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
||||||
|
|
||||||
@@ -92,7 +91,7 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
|||||||
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
||||||
|
|
||||||
### Roadmap
|
### Roadmap
|
||||||
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
package base62
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
|
||||||
base = uint32(len(alphabet))
|
|
||||||
)
|
|
||||||
|
|
||||||
// Encode encodes a uint32 value to a base62 string.
|
|
||||||
func Encode(num uint32) string {
|
|
||||||
if num == 0 {
|
|
||||||
return string(alphabet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var encoded strings.Builder
|
|
||||||
remainder := uint32(0)
|
|
||||||
|
|
||||||
for num > 0 {
|
|
||||||
remainder = num % base
|
|
||||||
encoded.WriteByte(alphabet[remainder])
|
|
||||||
num /= base
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reverse the encoded string
|
|
||||||
encodedString := encoded.String()
|
|
||||||
reversed := reverse(encodedString)
|
|
||||||
return reversed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode decodes a base62 string to a uint32 value.
|
|
||||||
func Decode(encoded string) (uint32, error) {
|
|
||||||
var decoded uint32
|
|
||||||
strLen := len(encoded)
|
|
||||||
|
|
||||||
for i, char := range encoded {
|
|
||||||
index := strings.IndexRune(alphabet, char)
|
|
||||||
if index < 0 {
|
|
||||||
return 0, fmt.Errorf("invalid character: %c", char)
|
|
||||||
}
|
|
||||||
|
|
||||||
decoded += uint32(index) * uint32(math.Pow(float64(base), float64(strLen-i-1)))
|
|
||||||
}
|
|
||||||
|
|
||||||
return decoded, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reverse a string.
|
|
||||||
func reverse(s string) string {
|
|
||||||
runes := []rune(s)
|
|
||||||
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
|
|
||||||
runes[i], runes[j] = runes[j], runes[i]
|
|
||||||
}
|
|
||||||
return string(runes)
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package base62
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEncodeDecode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
num uint32
|
|
||||||
}{
|
|
||||||
{0},
|
|
||||||
{1},
|
|
||||||
{42},
|
|
||||||
{12345},
|
|
||||||
{99999},
|
|
||||||
{123456789},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
encoded := Encode(tt.num)
|
|
||||||
decoded, err := Decode(encoded)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Decode error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if decoded != tt.num {
|
|
||||||
t.Errorf("Decode(%v) = %v, want %v", encoded, decoded, tt.num)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -7,9 +7,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
@@ -31,16 +29,6 @@ type IFaceDiscover interface {
|
|||||||
stdnet.ExternalIFaceDiscover
|
stdnet.ExternalIFaceDiscover
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteListener export internal RouteListener for mobile
|
|
||||||
type RouteListener interface {
|
|
||||||
routemanager.RouteListener
|
|
||||||
}
|
|
||||||
|
|
||||||
// DnsReadyListener export internal dns ReadyListener for mobile
|
|
||||||
type DnsReadyListener interface {
|
|
||||||
dns.ReadyListener
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
@@ -54,12 +42,13 @@ type Client struct {
|
|||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
ctxCancelLock *sync.Mutex
|
ctxCancelLock *sync.Mutex
|
||||||
deviceName string
|
deviceName string
|
||||||
routeListener routemanager.RouteListener
|
|
||||||
onHostDnsFn func([]string)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client {
|
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client {
|
||||||
|
lvl, _ := log.ParseLevel("trace")
|
||||||
|
log.SetLevel(lvl)
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
@@ -67,12 +56,11 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
|
|||||||
iFaceDiscover: iFaceDiscover,
|
iFaceDiscover: iFaceDiscover,
|
||||||
recorder: peer.NewRecorder(""),
|
recorder: peer.NewRecorder(""),
|
||||||
ctxCancelLock: &sync.Mutex{},
|
ctxCancelLock: &sync.Mutex{},
|
||||||
routeListener: routeListener,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
func (c *Client) Run(urlOpener URLOpener) error {
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
@@ -97,8 +85,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.onHostDnsFn = func([]string) {}
|
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover)
|
||||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -112,11 +99,6 @@ func (c *Client) Stop() {
|
|||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
|
||||||
func (c *Client) SetTraceLogLevel() {
|
|
||||||
log.SetLevel(log.TraceLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeersList return with the list of the PeerInfos
|
// PeersList return with the list of the PeerInfos
|
||||||
func (c *Client) PeersList() *PeerInfoArray {
|
func (c *Client) PeersList() *PeerInfoArray {
|
||||||
|
|
||||||
@@ -128,23 +110,14 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
p.ConnStatus.String(),
|
p.ConnStatus.String(),
|
||||||
|
p.Direct,
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
}
|
}
|
||||||
|
|
||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
|
||||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
|
||||||
dnsServer, err := dns.GetServerDns()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetConnectionListener set the network connection listener
|
// SetConnectionListener set the network connection listener
|
||||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||||
c.recorder.SetConnectionListener(listener)
|
c.recorder.SetConnectionListener(listener)
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package android
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// DNSList is a wrapper of []string
|
|
||||||
type DNSList struct {
|
|
||||||
items []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add new DNS address to the collection
|
|
||||||
func (array *DNSList) Add(s string) {
|
|
||||||
array.items = append(array.items, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get return an element of the collection
|
|
||||||
func (array *DNSList) Get(i int) (string, error) {
|
|
||||||
if i >= len(array.items) || i < 0 {
|
|
||||||
return "", fmt.Errorf("out of range")
|
|
||||||
}
|
|
||||||
return array.items[i], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size return with the size of the collection
|
|
||||||
func (array *DNSList) Size() int {
|
|
||||||
return len(array.items)
|
|
||||||
}
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
package android
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestDNSList_Get(t *testing.T) {
|
|
||||||
l := DNSList{
|
|
||||||
items: make([]string, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := l.Get(0)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("invalid error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = l.Get(-1)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expected error but got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = l.Get(1)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expected error but got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package android
|
|
||||||
|
|
||||||
import _ "golang.org/x/mobile/bind"
|
|
||||||
|
|
||||||
// to keep our CI/CD that checks go.mod and go.sum files happy, we need to import the package above
|
|
||||||
@@ -6,14 +6,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSOListener is async listener for mobile framework
|
// SSOListener is async listener for mobile framework
|
||||||
@@ -85,16 +86,10 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
|||||||
supportsSSO := true
|
supportsSSO := true
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
supportsSSO = false
|
supportsSSO = false
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -188,15 +183,27 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config)
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
s, ok := gstatus.FromError(err)
|
||||||
|
if ok && s.Code() == codes.NotFound {
|
||||||
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
|
"If you are using hosting Netbird see documentation at " +
|
||||||
|
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||||
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
|
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||||
|
|
||||||
|
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||||
@@ -204,7 +211,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
|||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type PeerInfo struct {
|
|||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
|
Direct bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -164,15 +163,31 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
s, ok := gstatus.FromError(err)
|
||||||
|
if ok && s.Code() == codes.NotFound {
|
||||||
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
|
"If you are using hosting Netbird see documentation at " +
|
||||||
|
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||||
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
|
mgmtURL := managementURL
|
||||||
|
if mgmtURL == "" {
|
||||||
|
mgmtURL = internal.DefaultManagementURL
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
|
"please update your servver or use Setup Keys to login", mgmtURL)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||||
|
|
||||||
|
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
@@ -181,7 +196,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -191,11 +206,9 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if userCode != "" {
|
|
||||||
if !strings.Contains(verificationURIComplete, userCode) {
|
if !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
err := open.Run(verificationURIComplete)
|
err := open.Run(verificationURIComplete)
|
||||||
cmd.Printf("Please do the SSO login in your browser. \n" +
|
cmd.Printf("Please do the SSO login in your browser. \n" +
|
||||||
|
|||||||
@@ -73,8 +73,7 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Debug(err)
|
log.Print(err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -93,10 +92,12 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
|||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Error: %v\n", err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
cmd.Printf("Couldn't connect. " +
|
||||||
"You can verify the connection by running:\n\n" +
|
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
||||||
" netbird status\n\n")
|
"Run the status command: \n\n" +
|
||||||
return err
|
" netbird status\n\n" +
|
||||||
|
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|||||||
@@ -2,23 +2,21 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"net"
|
"net"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
mgmt "github.com/netbirdio/netbird/management/server"
|
mgmt "github.com/netbirdio/netbird/management/server"
|
||||||
sigProto "github.com/netbirdio/netbird/signal/proto"
|
sigProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
sig "github.com/netbirdio/netbird/signal/server"
|
sig "github.com/netbirdio/netbird/signal/server"
|
||||||
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startTestingServices(t *testing.T) string {
|
func startTestingServices(t *testing.T) string {
|
||||||
@@ -65,7 +63,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
store, err := mgmt.NewFileStore(config.Datadir, nil)
|
store, err := mgmt.NewFileStore(config.Datadir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
|||||||
@@ -13,24 +13,22 @@ type Rule interface {
|
|||||||
GetRuleID() string
|
GetRuleID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RuleDirection is the traffic direction which a rule is applied
|
// Direction is the direction of the traffic
|
||||||
type RuleDirection int
|
type Direction int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// RuleDirectionIN applies to filters that handlers incoming traffic
|
// DirectionSrc is the direction of the traffic from the source
|
||||||
RuleDirectionIN RuleDirection = iota
|
DirectionSrc Direction = iota
|
||||||
// RuleDirectionOUT applies to filters that handlers outgoing traffic
|
// DirectionDst is the direction of the traffic from the destination
|
||||||
RuleDirectionOUT
|
DirectionDst
|
||||||
)
|
)
|
||||||
|
|
||||||
// Action is the action to be taken on a rule
|
// Action is the action to be taken on a rule
|
||||||
type Action int
|
type Action int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ActionUnknown is a unknown action
|
|
||||||
ActionUnknown Action = iota
|
|
||||||
// ActionAccept is the action to accept a packet
|
// ActionAccept is the action to accept a packet
|
||||||
ActionAccept
|
ActionAccept Action = iota
|
||||||
// ActionDrop is the action to drop a packet
|
// ActionDrop is the action to drop a packet
|
||||||
ActionDrop
|
ActionDrop
|
||||||
)
|
)
|
||||||
@@ -41,17 +39,11 @@ const (
|
|||||||
// Netbird client for ACL and routing functionality
|
// Netbird client for ACL and routing functionality
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
|
||||||
// If comment argument is empty firewall manager should set
|
|
||||||
// rule ID as comment for the rule
|
|
||||||
AddFiltering(
|
AddFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto Protocol,
|
port *Port,
|
||||||
sPort *Port,
|
direction Direction,
|
||||||
dPort *Port,
|
|
||||||
direction RuleDirection,
|
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
comment string,
|
||||||
) (Rule, error)
|
) (Rule, error)
|
||||||
|
|
||||||
@@ -61,8 +53,5 @@ type Manager interface {
|
|||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset() error
|
||||||
|
|
||||||
// Flush the changes to firewall controller
|
|
||||||
Flush() error
|
|
||||||
|
|
||||||
// TODO: migrate routemanager firewal actions to this interface
|
// TODO: migrate routemanager firewal actions to this interface
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,205 +8,83 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/nadoo/ipset"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
|
// ChainFilterName is the name of the chain that is used for filtering by the Netbird client
|
||||||
ChainInputFilterName = "NETBIRD-ACL-INPUT"
|
ChainFilterName = "NETBIRD-ACL"
|
||||||
|
|
||||||
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
|
|
||||||
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// dropAllDefaultRule in the Netbird chain
|
|
||||||
var dropAllDefaultRule = []string{"-j", "DROP"}
|
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
ipv6Client *iptables.IPTables
|
ipv6Client *iptables.IPTables
|
||||||
|
|
||||||
inputDefaultRuleSpecs []string
|
|
||||||
outputDefaultRuleSpecs []string
|
|
||||||
wgIface iFaceMapper
|
|
||||||
|
|
||||||
rulesets map[string]ruleset
|
|
||||||
}
|
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
|
||||||
type iFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
type ruleset struct {
|
|
||||||
rule *Rule
|
|
||||||
ips map[string]string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create() (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{}
|
||||||
wgIface: wgIface,
|
|
||||||
inputDefaultRuleSpecs: []string{
|
|
||||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
|
||||||
outputDefaultRuleSpecs: []string{
|
|
||||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
|
||||||
rulesets: make(map[string]ruleset),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Init(); err != nil {
|
|
||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// init clients for booth ipv4 and ipv6
|
// init clients for booth ipv4 and ipv6
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
if isIptablesClientAvailable(ipv4Client) {
|
|
||||||
m.ipv4Client = ipv4Client
|
m.ipv4Client = ipv4Client
|
||||||
}
|
|
||||||
|
|
||||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
return nil, fmt.Errorf("ip6tables is not installed in the system or not supported")
|
||||||
} else {
|
}
|
||||||
if isIptablesClientAvailable(ipv6Client) {
|
|
||||||
m.ipv6Client = ipv6Client
|
m.ipv6Client = ipv6Client
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
if err := m.Reset(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
return nil, fmt.Errorf("failed to reset firewall: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|
||||||
_, err := client.ListChains("filter")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
|
||||||
// If comment is empty rule ID is used as comment
|
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol fw.Protocol,
|
port *fw.Port,
|
||||||
sPort *fw.Port,
|
direction fw.Direction,
|
||||||
dPort *fw.Port,
|
|
||||||
direction fw.RuleDirection,
|
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
client := m.client(ip)
|
||||||
client, err := m.client(ip)
|
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create chain: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
|
||||||
|
return nil, fmt.Errorf("invalid port definition")
|
||||||
|
}
|
||||||
|
pv := strconv.Itoa(port.Values[0])
|
||||||
|
if port.IsRange {
|
||||||
|
pv += ":" + strconv.Itoa(port.Values[1])
|
||||||
|
}
|
||||||
|
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
|
||||||
|
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var dPortVal, sPortVal string
|
|
||||||
if dPort != nil && dPort.Values != nil {
|
|
||||||
// TODO: we support only one port per rule in current implementation of ACLs
|
|
||||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
|
||||||
}
|
|
||||||
if sPort != nil && sPort.Values != nil {
|
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
|
||||||
}
|
|
||||||
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
|
||||||
if comment == "" {
|
|
||||||
comment = ruleID
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipsetName != "" {
|
|
||||||
rs, rsExists := m.rulesets[ipsetName]
|
|
||||||
if !rsExists {
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
|
||||||
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
if err := ipset.Create(ipsetName); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rsExists {
|
|
||||||
// if ruleset already exists it means we already have the firewall rule
|
|
||||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
|
||||||
rs.ips[ip.String()] = ruleID
|
|
||||||
return &Rule{
|
|
||||||
ruleID: ruleID,
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
dst: direction == fw.RuleDirectionOUT,
|
|
||||||
v6: ip.To4() == nil,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
// this is new ipset so we need to create firewall rule for it
|
|
||||||
}
|
|
||||||
|
|
||||||
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
|
|
||||||
direction, action, comment, ipsetName)
|
|
||||||
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check is output rule already exists: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("input rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check is input rule already exists: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("input rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
rule := &Rule{
|
||||||
ruleID: ruleID,
|
id: uuid.New().String(),
|
||||||
specs: specs,
|
specs: specs,
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
dst: direction == fw.RuleDirectionOUT,
|
|
||||||
v6: ip.To4() == nil,
|
v6: ip.To4() == nil,
|
||||||
}
|
}
|
||||||
if ipsetName != "" {
|
|
||||||
// ipset name is defined and it means that this rule was created
|
|
||||||
// for it, need to assosiate it with ruleset
|
|
||||||
m.rulesets[ipsetName] = ruleset{
|
|
||||||
rule: rule,
|
|
||||||
ips: map[string]string{rule.ip: ruleID},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,224 +92,64 @@ func (m *Manager) AddFiltering(
|
|||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
r, ok := rule.(*Rule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
client := m.ipv4Client
|
client := m.ipv4Client
|
||||||
if r.v6 {
|
if r.v6 {
|
||||||
if m.ipv6Client == nil {
|
|
||||||
return fmt.Errorf("ipv6 is not supported")
|
|
||||||
}
|
|
||||||
client = m.ipv6Client
|
client = m.ipv6Client
|
||||||
}
|
}
|
||||||
|
return client.Delete("filter", ChainFilterName, r.specs...)
|
||||||
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
|
||||||
// delete IP from ruleset IPs list and ipset
|
|
||||||
if _, ok := rs.ips[r.ip]; ok {
|
|
||||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
|
||||||
}
|
|
||||||
delete(rs.ips, r.ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if after delete, set still contains other IPs,
|
|
||||||
// no need to delete firewall rule and we should exit here
|
|
||||||
if len(rs.ips) != 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// we delete last IP from the set, that means we need to delete
|
|
||||||
// set itself and assosiated firewall rule too
|
|
||||||
delete(m.rulesets, r.ipsetName)
|
|
||||||
|
|
||||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
|
||||||
log.Errorf("delete empty ipset: %v", err)
|
|
||||||
}
|
|
||||||
r = rs.rule
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.dst {
|
|
||||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
|
||||||
}
|
|
||||||
return client.Delete("filter", ChainInputFilterName, r.specs...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset() error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil {
|
||||||
if err := m.reset(m.ipv4Client, "filter"); err != nil {
|
return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err)
|
||||||
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err)
|
|
||||||
}
|
}
|
||||||
if m.ipv6Client != nil {
|
if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil {
|
||||||
if err := m.reset(m.ipv6Client, "filter"); err != nil {
|
return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err)
|
||||||
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
|
||||||
func (m *Manager) Flush() error { return nil }
|
|
||||||
|
|
||||||
// reset firewall chain, clear it and drop it
|
// reset firewall chain, clear it and drop it
|
||||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
func (m *Manager) reset(client *iptables.IPTables, table, chain string) error {
|
||||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
ok, err := client.ChainExists(table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check if input chain exists: %w", err)
|
return fmt.Errorf("failed to check if chain exists: %w", err)
|
||||||
}
|
}
|
||||||
if ok {
|
if !ok {
|
||||||
if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
|
||||||
return err
|
|
||||||
} else if ok {
|
|
||||||
if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
|
||||||
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = client.ChainExists(table, ChainOutputFilterName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to check if output chain exists: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
|
||||||
return err
|
|
||||||
} else if ok {
|
|
||||||
if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
|
||||||
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
|
|
||||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if err := client.ClearChain(table, ChainFilterName); err != nil {
|
||||||
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
|
return fmt.Errorf("failed to clear chain: %w", err)
|
||||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
return client.DeleteChain(table, ChainFilterName)
|
||||||
for ipsetName := range m.rulesets {
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
if err := ipset.Destroy(ipsetName); err != nil {
|
|
||||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
delete(m.rulesets, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func (m *Manager) filterRuleSpecs(
|
func (m *Manager) filterRuleSpecs(
|
||||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
table string, chain string, ip net.IP, port string,
|
||||||
direction fw.RuleDirection, action fw.Action, comment string,
|
direction fw.Direction, action fw.Action, comment string,
|
||||||
ipsetName string,
|
|
||||||
) (specs []string) {
|
) (specs []string) {
|
||||||
matchByIP := true
|
if direction == fw.DirectionSrc {
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
|
||||||
if s := ip.String(); s == "0.0.0.0" || s == "::" {
|
|
||||||
matchByIP = false
|
|
||||||
}
|
|
||||||
switch direction {
|
|
||||||
case fw.RuleDirectionIN:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-s", ip.String())
|
specs = append(specs, "-s", ip.String())
|
||||||
}
|
}
|
||||||
}
|
specs = append(specs, "-p", "tcp", "--dport", port)
|
||||||
case fw.RuleDirectionOUT:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-d", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if protocol != "all" {
|
|
||||||
specs = append(specs, "-p", protocol)
|
|
||||||
}
|
|
||||||
if sPort != "" {
|
|
||||||
specs = append(specs, "--sport", sPort)
|
|
||||||
}
|
|
||||||
if dPort != "" {
|
|
||||||
specs = append(specs, "--dport", dPort)
|
|
||||||
}
|
|
||||||
specs = append(specs, "-j", m.actionToStr(action))
|
specs = append(specs, "-j", m.actionToStr(action))
|
||||||
return append(specs, "-m", "comment", "--comment", comment)
|
return append(specs, "-m", "comment", "--comment", comment)
|
||||||
}
|
}
|
||||||
|
|
||||||
// rawClient returns corresponding iptables client for the given ip
|
// client returns corresponding iptables client for the given ip
|
||||||
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
|
func (m *Manager) client(ip net.IP) *iptables.IPTables {
|
||||||
if ip.To4() != nil {
|
if ip.To4() != nil {
|
||||||
return m.ipv4Client, nil
|
return m.ipv4Client
|
||||||
}
|
}
|
||||||
if m.ipv6Client == nil {
|
return m.ipv6Client
|
||||||
return nil, fmt.Errorf("ipv6 is not supported")
|
|
||||||
}
|
|
||||||
return m.ipv6Client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// client returns client with initialized chain and default rules
|
|
||||||
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
|
||||||
client, err := m.rawClient(ip)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := client.ChainExists("filter", ChainInputFilterName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create input chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = client.ChainExists("filter", ChainOutputFilterName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create output chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) actionToStr(action fw.Action) string {
|
func (m *Manager) actionToStr(action fw.Action) string {
|
||||||
@@ -440,16 +158,3 @@ func (m *Manager) actionToStr(action fw.Action) string {
|
|||||||
}
|
}
|
||||||
return "DROP"
|
return "DROP"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
|
||||||
if ipsetName == "" {
|
|
||||||
return ""
|
|
||||||
} else if sPort != "" && dPort != "" {
|
|
||||||
return ipsetName + "-sport-dport"
|
|
||||||
} else if sPort != "" {
|
|
||||||
return ipsetName + "-sport"
|
|
||||||
} else if dPort != "" {
|
|
||||||
return ipsetName + "-dport"
|
|
||||||
}
|
|
||||||
return ipsetName
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,261 +1,105 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
func TestNewManager(t *testing.T) {
|
||||||
type iFaceMock struct {
|
|
||||||
NameFunc func() string
|
|
||||||
AddressFunc func() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
|
||||||
if i.NameFunc != nil {
|
|
||||||
return i.NameFunc()
|
|
||||||
}
|
|
||||||
panic("NameFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
|
||||||
if i.AddressFunc != nil {
|
|
||||||
return i.AddressFunc()
|
|
||||||
}
|
|
||||||
panic("AddressFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesManager(t *testing.T) {
|
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
manager, err := Create()
|
||||||
manager, err := Create(mock)
|
if err != nil {
|
||||||
require.NoError(t, err)
|
t.Fatal(err)
|
||||||
|
}
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := manager.Reset()
|
|
||||||
require.NoError(t, err, "clear the manager state")
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var rule1 fw.Rule
|
var rule1 fw.Rule
|
||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
var rule2 fw.Rule
|
var rule2 fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
|
Proto: fw.PortProtocolTCP,
|
||||||
Values: []int{8043: 8046},
|
Values: []int{8043: 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
err := manager.DeleteRule(rule1)
|
if err := manager.DeleteRule(rule1); err != nil {
|
||||||
require.NoError(t, err, "failed to delete rule")
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, false, rule1.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
err := manager.DeleteRule(rule2)
|
if err := manager.DeleteRule(rule2); err != nil {
|
||||||
require.NoError(t, err, "failed to delete rule")
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
|
||||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = manager.Reset()
|
if err := manager.Reset(); err != nil {
|
||||||
require.NoError(t, err, "failed to reset")
|
t.Errorf("failed to reset: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName)
|
ok, err := ipv4Client.ChainExists("filter", ChainFilterName)
|
||||||
require.NoError(t, err, "failed check chain exists")
|
if err != nil {
|
||||||
|
t.Errorf("failed to drop chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName)
|
t.Errorf("chain '%v' still exists after Reset", ChainFilterName)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManagerIPSet(t *testing.T) {
|
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, mustExists bool, rulespec ...string) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
exists, err := ipv4Client.Exists("filter", ChainFilterName, rulespec...)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Errorf("failed to check rule: %v", err)
|
||||||
mock := &iFaceMock{
|
return
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
if !exists && mustExists {
|
||||||
manager, err := Create(mock)
|
t.Errorf("rule '%v' does not exist", rulespec)
|
||||||
require.NoError(t, err)
|
return
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := manager.Reset()
|
|
||||||
require.NoError(t, err, "clear the manager state")
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var rule1 fw.Rule
|
|
||||||
t.Run("add first rule with set", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddFiltering(
|
|
||||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
|
||||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
|
||||||
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
|
||||||
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 fw.Rule
|
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.3")
|
|
||||||
port := &fw.Port{
|
|
||||||
Values: []int{443},
|
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
if exists && !mustExists {
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
t.Errorf("rule '%v' exist", rulespec)
|
||||||
"default", "accept HTTPS traffic from ports range",
|
return
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
|
||||||
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
err := manager.DeleteRule(rule1)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
|
||||||
err := manager.DeleteRule(rule2)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "failed to reset")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
|
||||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
|
||||||
require.NoError(t, err, "failed to check rule")
|
|
||||||
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
|
||||||
require.Falsef(t, exists && !mustExists, "rule '%v' exist", rulespec)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesCreatePerformance(t *testing.T) {
|
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
|
||||||
// just check on the local interface
|
|
||||||
manager, err := Create(mock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := manager.Reset()
|
|
||||||
require.NoError(t, err, "clear the manager state")
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err = manager.client(net.ParseIP("10.20.0.100"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
|
||||||
start := time.Now()
|
|
||||||
for i := 0; i < testMax; i++ {
|
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
|
||||||
if i%2 == 0 {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
}
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,16 +2,12 @@ package iptables
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
ruleID string
|
id string
|
||||||
ipsetName string
|
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
ip string
|
|
||||||
dst bool
|
|
||||||
v6 bool
|
v6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.ruleID
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,758 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// FilterTableName is the name of the table that is used for filtering by the Netbird client
|
|
||||||
FilterTableName = "netbird-acl"
|
|
||||||
|
|
||||||
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
|
|
||||||
FilterInputChainName = "netbird-acl-input-filter"
|
|
||||||
|
|
||||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
|
||||||
FilterOutputChainName = "netbird-acl-output-filter"
|
|
||||||
)
|
|
||||||
|
|
||||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
|
||||||
|
|
||||||
// Manager of iptables firewall
|
|
||||||
type Manager struct {
|
|
||||||
mutex sync.Mutex
|
|
||||||
|
|
||||||
rConn *nftables.Conn
|
|
||||||
sConn *nftables.Conn
|
|
||||||
tableIPv4 *nftables.Table
|
|
||||||
tableIPv6 *nftables.Table
|
|
||||||
|
|
||||||
filterInputChainIPv4 *nftables.Chain
|
|
||||||
filterOutputChainIPv4 *nftables.Chain
|
|
||||||
|
|
||||||
filterInputChainIPv6 *nftables.Chain
|
|
||||||
filterOutputChainIPv6 *nftables.Chain
|
|
||||||
|
|
||||||
rulesetManager *rulesetManager
|
|
||||||
setRemovedIPs map[string]struct{}
|
|
||||||
setRemoved map[string]*nftables.Set
|
|
||||||
|
|
||||||
wgIface iFaceMapper
|
|
||||||
}
|
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
|
||||||
type iFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create nftables firewall manager
|
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
|
||||||
// sConn is used for creating sets and adding/removing elements from them
|
|
||||||
// it's differ then rConn (which does create new conn for each flush operation)
|
|
||||||
// and is permanent. Using same connection for booth type of operations
|
|
||||||
// overloads netlink with high amount of rules ( > 10000)
|
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &Manager{
|
|
||||||
rConn: &nftables.Conn{},
|
|
||||||
sConn: sConn,
|
|
||||||
|
|
||||||
rulesetManager: newRuleManager(),
|
|
||||||
setRemovedIPs: map[string]struct{}{},
|
|
||||||
setRemoved: map[string]*nftables.Set{},
|
|
||||||
|
|
||||||
wgIface: wgIface,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
|
||||||
//
|
|
||||||
// If comment argument is empty firewall manager should set
|
|
||||||
// rule ID as comment for the rule
|
|
||||||
func (m *Manager) AddFiltering(
|
|
||||||
ip net.IP,
|
|
||||||
proto fw.Protocol,
|
|
||||||
sPort *fw.Port,
|
|
||||||
dPort *fw.Port,
|
|
||||||
direction fw.RuleDirection,
|
|
||||||
action fw.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) (fw.Rule, error) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
ipset *nftables.Set
|
|
||||||
table *nftables.Table
|
|
||||||
chain *nftables.Chain
|
|
||||||
)
|
|
||||||
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
table, chain, err = m.chain(
|
|
||||||
ip,
|
|
||||||
FilterOutputChainName,
|
|
||||||
nftables.ChainHookOutput,
|
|
||||||
nftables.ChainPriorityFilter,
|
|
||||||
nftables.ChainTypeFilter)
|
|
||||||
} else {
|
|
||||||
table, chain, err = m.chain(
|
|
||||||
ip,
|
|
||||||
FilterInputChainName,
|
|
||||||
nftables.ChainHookInput,
|
|
||||||
nftables.ChainPriorityFilter,
|
|
||||||
nftables.ChainTypeFilter)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rawIP := ip.To4()
|
|
||||||
if rawIP == nil {
|
|
||||||
rawIP = ip.To16()
|
|
||||||
}
|
|
||||||
|
|
||||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
|
||||||
|
|
||||||
if ipsetName != "" {
|
|
||||||
// if we already have set with given name, just add ip to the set
|
|
||||||
// and return rule with new ID in other case let's create rule
|
|
||||||
// with fresh created set and set element
|
|
||||||
|
|
||||||
var isSetNew bool
|
|
||||||
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
|
||||||
if err != nil {
|
|
||||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
|
||||||
return nil, fmt.Errorf("get set name: %v", err)
|
|
||||||
}
|
|
||||||
isSetNew = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
|
||||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
|
||||||
}
|
|
||||||
if err := m.sConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isSetNew {
|
|
||||||
// if we already have nftables rules with set for given direction
|
|
||||||
// just add new rule to the ruleset and return new fw.Rule object
|
|
||||||
|
|
||||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
|
||||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
|
||||||
}
|
|
||||||
// if ipset exists but it is not linked to rule for given direction
|
|
||||||
// create new rule for direction and bind ipset to it later
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
|
||||||
}
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if proto != "all" {
|
|
||||||
expressions = append(expressions, &expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(9),
|
|
||||||
Len: uint32(1),
|
|
||||||
})
|
|
||||||
|
|
||||||
var protoData []byte
|
|
||||||
switch proto {
|
|
||||||
case fw.ProtocolTCP:
|
|
||||||
protoData = []byte{unix.IPPROTO_TCP}
|
|
||||||
case fw.ProtocolUDP:
|
|
||||||
protoData = []byte{unix.IPPROTO_UDP}
|
|
||||||
case fw.ProtocolICMP:
|
|
||||||
protoData = []byte{unix.IPPROTO_ICMP}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
expressions = append(expressions, &expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Data: protoData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
|
||||||
// in that case not add IP match expression into the rule definition
|
|
||||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
|
||||||
// source address position
|
|
||||||
addrLen := uint32(len(rawIP))
|
|
||||||
addrOffset := uint32(12)
|
|
||||||
if addrLen == 16 {
|
|
||||||
addrOffset = 8
|
|
||||||
}
|
|
||||||
|
|
||||||
// change to destination address position if need
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
addrOffset += addrLen
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: addrOffset,
|
|
||||||
Len: addrLen,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
// add individual IP for match if no ipset defined
|
|
||||||
if ipset == nil {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: rawIP,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ipsetName,
|
|
||||||
SetID: ipset.ID,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 0,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*sPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) != 0 {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*dPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == fw.ActionAccept {
|
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
|
||||||
} else {
|
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
|
||||||
|
|
||||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: chain,
|
|
||||||
Position: 0,
|
|
||||||
Exprs: expressions,
|
|
||||||
UserData: userData,
|
|
||||||
})
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
|
||||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRulesetID returns ruleset ID based on given parameters
|
|
||||||
func (m *Manager) getRulesetID(
|
|
||||||
ip net.IP,
|
|
||||||
proto fw.Protocol,
|
|
||||||
sPort *fw.Port,
|
|
||||||
dPort *fw.Port,
|
|
||||||
direction fw.RuleDirection,
|
|
||||||
action fw.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) string {
|
|
||||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
|
||||||
if sPort != nil {
|
|
||||||
rulesetID += sPort.String()
|
|
||||||
}
|
|
||||||
rulesetID += ":"
|
|
||||||
if dPort != nil {
|
|
||||||
rulesetID += dPort.String()
|
|
||||||
}
|
|
||||||
rulesetID += ":"
|
|
||||||
rulesetID += strconv.Itoa(int(action))
|
|
||||||
if ipsetName == "" {
|
|
||||||
return "ip:" + ip.String() + rulesetID
|
|
||||||
}
|
|
||||||
return "set:" + ipsetName + rulesetID
|
|
||||||
}
|
|
||||||
|
|
||||||
// createSet in given table by name
|
|
||||||
func (m *Manager) createSet(
|
|
||||||
table *nftables.Table,
|
|
||||||
rawIP []byte,
|
|
||||||
name string,
|
|
||||||
) (*nftables.Set, error) {
|
|
||||||
keyType := nftables.TypeIPAddr
|
|
||||||
if len(rawIP) == 16 {
|
|
||||||
keyType = nftables.TypeIP6Addr
|
|
||||||
}
|
|
||||||
// else we create new ipset and continue creating rule
|
|
||||||
ipset := &nftables.Set{
|
|
||||||
Name: name,
|
|
||||||
Table: table,
|
|
||||||
Dynamic: true,
|
|
||||||
KeyType: keyType,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
|
||||||
return nil, fmt.Errorf("create set: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush created set: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipset, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// chain returns the chain for the given IP address with specific settings
|
|
||||||
func (m *Manager) chain(
|
|
||||||
ip net.IP,
|
|
||||||
name string,
|
|
||||||
hook nftables.ChainHook,
|
|
||||||
priority nftables.ChainPriority,
|
|
||||||
cType nftables.ChainType,
|
|
||||||
) (*nftables.Table, *nftables.Chain, error) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
|
|
||||||
if c != nil {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
return m.createChainIfNotExists(tf, name, hook, priority, cType)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
if name == FilterInputChainName {
|
|
||||||
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
|
|
||||||
return m.tableIPv4, m.filterInputChainIPv4, err
|
|
||||||
}
|
|
||||||
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
|
|
||||||
return m.tableIPv4, m.filterOutputChainIPv4, err
|
|
||||||
}
|
|
||||||
if name == FilterInputChainName {
|
|
||||||
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
|
|
||||||
return m.tableIPv4, m.filterInputChainIPv6, err
|
|
||||||
}
|
|
||||||
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
|
|
||||||
return m.tableIPv4, m.filterOutputChainIPv6, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// table returns the table for the given family of the IP address
|
|
||||||
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|
||||||
if family == nftables.TableFamilyIPv4 {
|
|
||||||
if m.tableIPv4 != nil {
|
|
||||||
return m.tableIPv4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.tableIPv4 = table
|
|
||||||
return m.tableIPv4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tableIPv6 != nil {
|
|
||||||
return m.tableIPv6, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.tableIPv6 = table
|
|
||||||
return m.tableIPv6, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
|
||||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == FilterTableName {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return table, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) createChainIfNotExists(
|
|
||||||
family nftables.TableFamily,
|
|
||||||
name string,
|
|
||||||
hooknum nftables.ChainHook,
|
|
||||||
priority nftables.ChainPriority,
|
|
||||||
chainType nftables.ChainType,
|
|
||||||
) (*nftables.Chain, error) {
|
|
||||||
table, err := m.table(family)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list of chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range chains {
|
|
||||||
if c.Name == name && c.Table.Name == table.Name {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
polAccept := nftables.ChainPolicyAccept
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: name,
|
|
||||||
Table: table,
|
|
||||||
Hooknum: hooknum,
|
|
||||||
Priority: priority,
|
|
||||||
Type: chainType,
|
|
||||||
Policy: &polAccept,
|
|
||||||
}
|
|
||||||
|
|
||||||
chain = m.rConn.AddChain(chain)
|
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
|
||||||
shiftDSTAddr := 0
|
|
||||||
if name == FilterOutputChainName {
|
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
|
||||||
shiftDSTAddr = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
|
|
||||||
if m.wgIface.Address().IP.To4() == nil {
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(8 + (16 * shiftDSTAddr)),
|
|
||||||
Len: 16,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 16,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: mask.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(12 + (4 * shiftDSTAddr)),
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
|
|
||||||
expressions = []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
|
||||||
}
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chain, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
nativeRule, ok := rule.(*Rule)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid rule type")
|
|
||||||
}
|
|
||||||
|
|
||||||
if nativeRule.nftRule == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if nativeRule.nftSet != nil {
|
|
||||||
// call twice of delete set element raises error
|
|
||||||
// so we need to check if element is already removed
|
|
||||||
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
|
||||||
if _, ok := m.setRemovedIPs[key]; !ok {
|
|
||||||
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
|
||||||
}
|
|
||||||
if err := m.sConn.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.setRemovedIPs[key] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.rulesetManager.deleteRule(nativeRule) {
|
|
||||||
// deleteRule indicates that we still have IP in the ruleset
|
|
||||||
// it means we should not remove the nftables rule but need to update set
|
|
||||||
// so we prepare IP to be removed from set on the next flush call
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
|
||||||
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
|
||||||
}
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nativeRule.nftRule = nil
|
|
||||||
|
|
||||||
if nativeRule.nftSet != nil {
|
|
||||||
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
|
||||||
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
|
||||||
}
|
|
||||||
nativeRule.nftSet = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChains()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
|
||||||
}
|
|
||||||
for _, c := range chains {
|
|
||||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
|
||||||
m.rConn.DelChain(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := m.rConn.ListTables()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
|
||||||
}
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == FilterTableName {
|
|
||||||
m.rConn.DelTable(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.rConn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
|
||||||
//
|
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
|
||||||
func (m *Manager) Flush() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
if err := m.flushWithBackoff(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set must be removed after flush rule changes
|
|
||||||
// otherwise we will get error
|
|
||||||
for _, s := range m.setRemoved {
|
|
||||||
m.rConn.FlushSet(s)
|
|
||||||
m.rConn.DelSet(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.setRemoved) > 0 {
|
|
||||||
if err := m.flushWithBackoff(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.setRemovedIPs = map[string]struct{}{}
|
|
||||||
m.setRemoved = map[string]*nftables.Set{}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) flushWithBackoff() (err error) {
|
|
||||||
backoff := 4
|
|
||||||
backoffTime := 1000 * time.Millisecond
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
if !strings.Contains(err.Error(), "busy") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error("failed to flush nftables, retrying...")
|
|
||||||
if i == backoff-1 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
time.Sleep(backoffTime)
|
|
||||||
backoffTime = backoffTime * 2
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
|
||||||
if table == nil || chain == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
list, err := m.rConn.GetRules(table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range list {
|
|
||||||
if len(rule.UserData) != 0 {
|
|
||||||
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
|
||||||
log.Errorf("failed to set rule handle: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodePort(port fw.Port) []byte {
|
|
||||||
bs := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
|
||||||
return bs
|
|
||||||
}
|
|
||||||
|
|
||||||
func ifname(n string) []byte {
|
|
||||||
b := make([]byte, 16)
|
|
||||||
copy(b, []byte(n+"\x00"))
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
@@ -1,207 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
|
||||||
type iFaceMock struct {
|
|
||||||
NameFunc func() string
|
|
||||||
AddressFunc func() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
|
||||||
if i.NameFunc != nil {
|
|
||||||
return i.NameFunc()
|
|
||||||
}
|
|
||||||
panic("NameFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
|
||||||
if i.AddressFunc != nil {
|
|
||||||
return i.AddressFunc()
|
|
||||||
}
|
|
||||||
panic("AddressFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNftablesManager(t *testing.T) {
|
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// just check on the local interface
|
|
||||||
manager, err := Create(mock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "failed to reset")
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
rule, err := manager.AddFiltering(
|
|
||||||
ip,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []int{53}},
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionDrop,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
err = manager.Flush()
|
|
||||||
require.NoError(t, err, "failed to flush")
|
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
|
||||||
require.NoError(t, err, "failed to get rules")
|
|
||||||
|
|
||||||
// test expectations:
|
|
||||||
// 1) regular rule
|
|
||||||
// 2) "accept extra routed traffic rule" for the interface
|
|
||||||
// 3) "drop all rule" for the interface
|
|
||||||
require.Len(t, rules, 3, "expected 3 rules")
|
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
expectedExprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(9),
|
|
||||||
Len: uint32(1),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Data: []byte{unix.IPPROTO_TCP},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 53},
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
|
||||||
}
|
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
|
||||||
|
|
||||||
err = manager.DeleteRule(rule)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
err = manager.Flush()
|
|
||||||
require.NoError(t, err, "failed to flush")
|
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
|
||||||
require.NoError(t, err, "failed to get rules")
|
|
||||||
// test expectations:
|
|
||||||
// 1) "accept extra routed traffic rule" for the interface
|
|
||||||
// 2) "drop all rule" for the interface
|
|
||||||
require.Len(t, rules, 2, "expected 2 rules after deleteion")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "failed to reset")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNFtablesCreatePerformance(t *testing.T) {
|
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
|
||||||
// just check on the local interface
|
|
||||||
manager, err := Create(mock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := manager.Reset(); err != nil {
|
|
||||||
t.Errorf("clear the manager state: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
|
||||||
start := time.Now()
|
|
||||||
for i := 0; i < testMax; i++ {
|
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
|
||||||
if i%2 == 0 {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
if i%100 == 0 {
|
|
||||||
err = manager.Flush()
|
|
||||||
require.NoError(t, err, "failed to flush")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/google/nftables"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rule to handle management of rules
|
|
||||||
type Rule struct {
|
|
||||||
nftRule *nftables.Rule
|
|
||||||
nftSet *nftables.Set
|
|
||||||
|
|
||||||
ruleID string
|
|
||||||
ip []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
|
||||||
func (r *Rule) GetRuleID() string {
|
|
||||||
return r.ruleID
|
|
||||||
}
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
|
||||||
type nftRuleset struct {
|
|
||||||
nftRule *nftables.Rule
|
|
||||||
nftSet *nftables.Set
|
|
||||||
issuedRules map[string]*Rule
|
|
||||||
rulesetID string
|
|
||||||
}
|
|
||||||
|
|
||||||
type rulesetManager struct {
|
|
||||||
rulesets map[string]*nftRuleset
|
|
||||||
|
|
||||||
nftSetName2rulesetID map[string]string
|
|
||||||
issuedRuleID2rulesetID map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRuleManager() *rulesetManager {
|
|
||||||
return &rulesetManager{
|
|
||||||
rulesets: map[string]*nftRuleset{},
|
|
||||||
|
|
||||||
nftSetName2rulesetID: map[string]string{},
|
|
||||||
issuedRuleID2rulesetID: map[string]string{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
|
||||||
ruleset, ok := r.rulesets[rulesetID]
|
|
||||||
return ruleset, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rulesetManager) createRuleset(
|
|
||||||
rulesetID string,
|
|
||||||
nftRule *nftables.Rule,
|
|
||||||
nftSet *nftables.Set,
|
|
||||||
) *nftRuleset {
|
|
||||||
ruleset := nftRuleset{
|
|
||||||
rulesetID: rulesetID,
|
|
||||||
nftRule: nftRule,
|
|
||||||
nftSet: nftSet,
|
|
||||||
issuedRules: map[string]*Rule{},
|
|
||||||
}
|
|
||||||
r.rulesets[ruleset.rulesetID] = &ruleset
|
|
||||||
if nftSet != nil {
|
|
||||||
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
|
||||||
}
|
|
||||||
return &ruleset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rulesetManager) addRule(
|
|
||||||
ruleset *nftRuleset,
|
|
||||||
ip []byte,
|
|
||||||
) (*Rule, error) {
|
|
||||||
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
|
||||||
return nil, fmt.Errorf("ruleset not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := Rule{
|
|
||||||
nftRule: ruleset.nftRule,
|
|
||||||
nftSet: ruleset.nftSet,
|
|
||||||
ruleID: xid.New().String(),
|
|
||||||
ip: ip,
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleset.issuedRules[rule.ruleID] = &rule
|
|
||||||
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
|
||||||
|
|
||||||
return &rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteRule from ruleset and returns true if contains other rules
|
|
||||||
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
|
||||||
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleset := r.rulesets[rulesetID]
|
|
||||||
if ruleset.nftRule == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
|
||||||
delete(ruleset.issuedRules, rule.ruleID)
|
|
||||||
|
|
||||||
if len(ruleset.issuedRules) == 0 {
|
|
||||||
delete(r.rulesets, ruleset.rulesetID)
|
|
||||||
if rule.nftSet != nil {
|
|
||||||
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
|
||||||
//
|
|
||||||
// This is important to do, because after we add rule to the nftables we can't update it until
|
|
||||||
// we set correct handle value to it.
|
|
||||||
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
|
||||||
split := bytes.Split(nftRule.UserData, []byte(" "))
|
|
||||||
ruleset, ok := r.rulesets[string(split[0])]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("ruleset not found")
|
|
||||||
}
|
|
||||||
*ruleset.nftRule = *nftRule
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRulesetManager_createRuleset(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{
|
|
||||||
UserData: []byte(rulesetID),
|
|
||||||
}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
|
||||||
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
|
||||||
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_addRule(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
|
|
||||||
// Add a rule to the ruleset.
|
|
||||||
ip := []byte("192.168.1.1")
|
|
||||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule, "rule should not be nil")
|
|
||||||
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
|
||||||
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
|
||||||
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
|
||||||
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
|
||||||
|
|
||||||
ruleset2 := &nftRuleset{
|
|
||||||
rulesetID: "ruleset-2",
|
|
||||||
}
|
|
||||||
_, err = rulesetManager.addRule(ruleset2, ip)
|
|
||||||
require.Error(t, err, "addRule() should have failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_deleteRule(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
|
|
||||||
// Add a rule to the ruleset.
|
|
||||||
ip := []byte("192.168.1.1")
|
|
||||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule, "rule should not be nil")
|
|
||||||
|
|
||||||
ip2 := []byte("192.168.1.1")
|
|
||||||
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule2, "rule should not be nil")
|
|
||||||
|
|
||||||
hasNext := rulesetManager.deleteRule(rule)
|
|
||||||
require.True(t, hasNext, "deleteRule() should have returned true")
|
|
||||||
|
|
||||||
// Check that the rule is no longer in the manager.
|
|
||||||
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
|
||||||
|
|
||||||
hasNext = rulesetManager.deleteRule(rule2)
|
|
||||||
require.False(t, hasNext, "deleteRule() should have returned false")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
// Add a rule to the ruleset.
|
|
||||||
ip := []byte("192.168.0.1")
|
|
||||||
|
|
||||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule, "rule should not be nil")
|
|
||||||
|
|
||||||
nftRuleCopy := nftRule
|
|
||||||
nftRuleCopy.Handle = 2
|
|
||||||
nftRuleCopy.UserData = []byte(rulesetID)
|
|
||||||
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
|
||||||
require.NoError(t, err, "setNftRuleHandle() failed")
|
|
||||||
// check correct work with references
|
|
||||||
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_getRuleset(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
nftSet := nftables.Set{
|
|
||||||
ID: 2,
|
|
||||||
}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
|
||||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
|
||||||
|
|
||||||
find, ok := rulesetManager.getRuleset(rulesetID)
|
|
||||||
require.True(t, ok, "getRuleset() failed")
|
|
||||||
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
|
||||||
|
|
||||||
_, ok = rulesetManager.getRuleset("does-not-exist")
|
|
||||||
require.False(t, ok, "getRuleset() failed")
|
|
||||||
}
|
|
||||||
@@ -1,27 +1,14 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
// PortProtocol is the protocol of the port
|
||||||
"strconv"
|
type PortProtocol string
|
||||||
)
|
|
||||||
|
|
||||||
// Protocol is the protocol of the port
|
|
||||||
type Protocol string
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ProtocolTCP is the TCP protocol
|
// PortProtocolTCP is the TCP protocol
|
||||||
ProtocolTCP Protocol = "tcp"
|
PortProtocolTCP PortProtocol = "tcp"
|
||||||
|
|
||||||
// ProtocolUDP is the UDP protocol
|
// PortProtocolUDP is the UDP protocol
|
||||||
ProtocolUDP Protocol = "udp"
|
PortProtocolUDP PortProtocol = "udp"
|
||||||
|
|
||||||
// ProtocolICMP is the ICMP protocol
|
|
||||||
ProtocolICMP Protocol = "icmp"
|
|
||||||
|
|
||||||
// ProtocolALL cover all supported protocols
|
|
||||||
ProtocolALL Protocol = "all"
|
|
||||||
|
|
||||||
// ProtocolUnknown unknown protocol
|
|
||||||
ProtocolUnknown Protocol = "unknown"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Port of the address for firewall rule
|
// Port of the address for firewall rule
|
||||||
@@ -31,16 +18,7 @@ type Port struct {
|
|||||||
|
|
||||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||||
Values []int
|
Values []int
|
||||||
}
|
|
||||||
|
|
||||||
// String interface implementation
|
// Proto is the protocol of the port
|
||||||
func (p *Port) String() string {
|
Proto PortProtocol
|
||||||
var ports string
|
|
||||||
for _, port := range p.Values {
|
|
||||||
if ports != "" {
|
|
||||||
ports += ","
|
|
||||||
}
|
|
||||||
ports += strconv.Itoa(port)
|
|
||||||
}
|
|
||||||
return ports
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rule to handle management of rules
|
|
||||||
type Rule struct {
|
|
||||||
id string
|
|
||||||
ip net.IP
|
|
||||||
ipLayer gopacket.LayerType
|
|
||||||
matchByIP bool
|
|
||||||
protoLayer gopacket.LayerType
|
|
||||||
direction fw.RuleDirection
|
|
||||||
sPort uint16
|
|
||||||
dPort uint16
|
|
||||||
drop bool
|
|
||||||
comment string
|
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
|
||||||
func (r *Rule) GetRuleID() string {
|
|
||||||
return r.id
|
|
||||||
}
|
|
||||||
@@ -1,377 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
const layerTypeAll = 0
|
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
SetFilter(iface.PacketFilter) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
|
||||||
type RuleSet map[string]Rule
|
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
|
||||||
type Manager struct {
|
|
||||||
outgoingRules map[string]RuleSet
|
|
||||||
incomingRules map[string]RuleSet
|
|
||||||
wgNetwork *net.IPNet
|
|
||||||
decoders sync.Pool
|
|
||||||
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// decoder for packages
|
|
||||||
type decoder struct {
|
|
||||||
eth layers.Ethernet
|
|
||||||
ip4 layers.IPv4
|
|
||||||
ip6 layers.IPv6
|
|
||||||
tcp layers.TCP
|
|
||||||
udp layers.UDP
|
|
||||||
icmp4 layers.ICMPv4
|
|
||||||
icmp6 layers.ICMPv6
|
|
||||||
decoded []gopacket.LayerType
|
|
||||||
parser *gopacket.DecodingLayerParser
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
|
||||||
m := &Manager{
|
|
||||||
decoders: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
d := &decoder{
|
|
||||||
decoded: []gopacket.LayerType{},
|
|
||||||
}
|
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv4,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser.IgnoreUnsupported = true
|
|
||||||
return d
|
|
||||||
},
|
|
||||||
},
|
|
||||||
outgoingRules: make(map[string]RuleSet),
|
|
||||||
incomingRules: make(map[string]RuleSet),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
|
||||||
//
|
|
||||||
// If comment argument is empty firewall manager should set
|
|
||||||
// rule ID as comment for the rule
|
|
||||||
func (m *Manager) AddFiltering(
|
|
||||||
ip net.IP,
|
|
||||||
proto fw.Protocol,
|
|
||||||
sPort *fw.Port,
|
|
||||||
dPort *fw.Port,
|
|
||||||
direction fw.RuleDirection,
|
|
||||||
action fw.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) (fw.Rule, error) {
|
|
||||||
r := Rule{
|
|
||||||
id: uuid.New().String(),
|
|
||||||
ip: ip,
|
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
|
||||||
matchByIP: true,
|
|
||||||
direction: direction,
|
|
||||||
drop: action == fw.ActionDrop,
|
|
||||||
comment: comment,
|
|
||||||
}
|
|
||||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
|
||||||
r.ip = ipNormalized
|
|
||||||
}
|
|
||||||
|
|
||||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
|
||||||
r.matchByIP = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) == 1 {
|
|
||||||
r.sPort = uint16(sPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) == 1 {
|
|
||||||
r.dPort = uint16(dPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case fw.ProtocolTCP:
|
|
||||||
r.protoLayer = layers.LayerTypeTCP
|
|
||||||
case fw.ProtocolUDP:
|
|
||||||
r.protoLayer = layers.LayerTypeUDP
|
|
||||||
case fw.ProtocolICMP:
|
|
||||||
r.protoLayer = layers.LayerTypeICMPv4
|
|
||||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
|
||||||
r.protoLayer = layers.LayerTypeICMPv6
|
|
||||||
}
|
|
||||||
case fw.ProtocolALL:
|
|
||||||
r.protoLayer = layerTypeAll
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if direction == fw.RuleDirectionIN {
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
|
||||||
}
|
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
|
||||||
} else {
|
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
|
||||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
|
||||||
}
|
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
return &r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.direction == fw.RuleDirectionIN {
|
|
||||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
|
||||||
} else {
|
|
||||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
|
||||||
func (m *Manager) Flush() error { return nil }
|
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
|
||||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
|
||||||
return m.dropFilter(packetData, m.incomingRules, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropFilter imlements same logic for booth direction of the traffic
|
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
|
||||||
defer m.decoders.Put(d)
|
|
||||||
|
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
|
||||||
log.Tracef("not enough levels in network packet")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
ipLayer := d.decoded[0]
|
|
||||||
|
|
||||||
switch ipLayer {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var ip net.IP
|
|
||||||
switch ipLayer {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip4.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip4.DstIP
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip6.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip6.DstIP
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
|
||||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
|
||||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
|
||||||
|
|
||||||
// default policy is DROP ALL
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
for _, rule := range rules {
|
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch payloadLayer {
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
|
||||||
// we ignore rule.drop and call this hook
|
|
||||||
if rule.udpHook != nil {
|
|
||||||
return rule.udpHook(packetData), true
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
return rule.drop, true
|
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
|
||||||
m.wgNetwork = network
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
|
||||||
//
|
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
|
||||||
func (m *Manager) AddUDPPacketHook(
|
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
|
||||||
) string {
|
|
||||||
r := Rule{
|
|
||||||
id: uuid.New().String(),
|
|
||||||
ip: ip,
|
|
||||||
protoLayer: layers.LayerTypeUDP,
|
|
||||||
dPort: dPort,
|
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
|
||||||
direction: fw.RuleDirectionOUT,
|
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
|
||||||
udpHook: hook,
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if in {
|
|
||||||
r.direction = fw.RuleDirectionIN
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
|
||||||
}
|
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
|
||||||
} else {
|
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
|
||||||
}
|
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
return r.id
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
|
||||||
for _, arr := range m.incomingRules {
|
|
||||||
for _, r := range arr {
|
|
||||||
if r.id == hookID {
|
|
||||||
return m.DeleteRule(&r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, arr := range m.outgoingRules {
|
|
||||||
for _, r := range arr {
|
|
||||||
if r.id == hookID {
|
|
||||||
return m.DeleteRule(&r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("hook with given id not found")
|
|
||||||
}
|
|
||||||
@@ -1,403 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
type IFaceMock struct {
|
|
||||||
SetFilterFunc func(iface.PacketFilter) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
|
||||||
if i.SetFilterFunc == nil {
|
|
||||||
return fmt.Errorf("not implemented")
|
|
||||||
}
|
|
||||||
return i.SetFilterFunc(iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerCreate(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if m == nil {
|
|
||||||
t.Error("Manager is nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerAddFiltering(t *testing.T) {
|
|
||||||
isSetFilterCalled := false
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error {
|
|
||||||
isSetFilterCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []int{80}}
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule == nil {
|
|
||||||
t.Error("Rule is nil")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isSetFilterCalled {
|
|
||||||
t.Error("SetFilter was not called")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerDeleteRule(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []int{80}}
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip = net.ParseIP("192.168.1.1")
|
|
||||||
proto = fw.ProtocolTCP
|
|
||||||
port = &fw.Port{Values: []int{80}}
|
|
||||||
direction = fw.RuleDirectionIN
|
|
||||||
action = fw.ActionDrop
|
|
||||||
comment = "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.DeleteRule(rule)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.DeleteRule(rule2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddUDPPacketHook(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in bool
|
|
||||||
expDir fw.RuleDirection
|
|
||||||
ip net.IP
|
|
||||||
dPort uint16
|
|
||||||
hook func([]byte) bool
|
|
||||||
expectedID string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Test Outgoing UDP Packet Hook",
|
|
||||||
in: false,
|
|
||||||
expDir: fw.RuleDirectionOUT,
|
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
|
||||||
dPort: 8000,
|
|
||||||
hook: func([]byte) bool { return true },
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Test Incoming UDP Packet Hook",
|
|
||||||
in: true,
|
|
||||||
expDir: fw.RuleDirectionIN,
|
|
||||||
ip: net.IPv6loopback,
|
|
||||||
dPort: 9000,
|
|
||||||
hook: func([]byte) bool { return false },
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
manager := &Manager{
|
|
||||||
incomingRules: map[string]RuleSet{},
|
|
||||||
outgoingRules: map[string]RuleSet{},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
|
||||||
|
|
||||||
var addedRule Rule
|
|
||||||
if tt.in {
|
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
|
||||||
addedRule = rule
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if len(manager.outgoingRules) != 1 {
|
|
||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
|
||||||
addedRule = rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.dPort != addedRule.dPort {
|
|
||||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
|
||||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.expDir != addedRule.direction {
|
|
||||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if addedRule.udpHook == nil {
|
|
||||||
t.Errorf("expected udpHook to be set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []int{80}}
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.Reset()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
|
||||||
t.Errorf("rules is not empty")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotMatchByIP(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
|
||||||
proto := fw.ProtocolUDP
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionAccept
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4 := &layers.IPv4{
|
|
||||||
TTL: 64,
|
|
||||||
Version: 4,
|
|
||||||
SrcIP: net.ParseIP("100.10.0.1"),
|
|
||||||
DstIP: net.ParseIP("100.10.0.100"),
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
}
|
|
||||||
udp := &layers.UDP{
|
|
||||||
SrcPort: 51334,
|
|
||||||
DstPort: 53,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
|
|
||||||
t.Errorf("failed to set network layer for checksum: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload := gopacket.Payload([]byte("test"))
|
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
|
||||||
opts := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
if err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload); err != nil {
|
|
||||||
t.Errorf("failed to serialize packet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
|
|
||||||
t.Errorf("expected packet to be accepted")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = m.Reset(); err != nil {
|
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRemovePacketHook tests the functionality of the RemovePacketHook method
|
|
||||||
func TestRemovePacketHook(t *testing.T) {
|
|
||||||
// creating mock iface
|
|
||||||
iface := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
// creating manager instance
|
|
||||||
manager, err := Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add a UDP packet hook
|
|
||||||
hookFunc := func(data []byte) bool { return true }
|
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
|
||||||
found := false
|
|
||||||
for _, arr := range manager.outgoingRules {
|
|
||||||
for _, rule := range arr {
|
|
||||||
if rule.id == hookID {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("The hook was not added properly.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now remove the packet hook
|
|
||||||
err = manager.RemovePacketHook(hookID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to remove hook: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
|
||||||
for _, arr := range manager.outgoingRules {
|
|
||||||
for _, rule := range arr {
|
|
||||||
if rule.id == hookID {
|
|
||||||
t.Fatalf("The hook was not removed properly.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
|
||||||
// just check on the local interface
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
manager, err := Create(ifaceMock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := manager.Reset(); err != nil {
|
|
||||||
t.Errorf("clear the manager state: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
|
||||||
start := time.Now()
|
|
||||||
for i := 0; i < testMax; i++ {
|
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
|
||||||
if i%2 == 0 {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
}
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,493 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/md5"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
SetFilter(iface.PacketFilter) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
|
||||||
type Manager interface {
|
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
|
||||||
Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
|
||||||
type DefaultManager struct {
|
|
||||||
manager firewall.Manager
|
|
||||||
ipsetCounter int
|
|
||||||
rulesPairs map[string][]firewall.Rule
|
|
||||||
mutex sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
type ipsetInfo struct {
|
|
||||||
name string
|
|
||||||
ipCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
|
||||||
return &DefaultManager{
|
|
||||||
manager: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
|
||||||
//
|
|
||||||
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
|
|
||||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|
||||||
d.mutex.Lock()
|
|
||||||
defer d.mutex.Unlock()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
total := 0
|
|
||||||
for _, pairs := range d.rulesPairs {
|
|
||||||
total += len(pairs)
|
|
||||||
}
|
|
||||||
log.Infof(
|
|
||||||
"ACL rules processed in: %v, total rules count: %d",
|
|
||||||
time.Since(start), total)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if d.manager == nil {
|
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := d.manager.Flush(); err != nil {
|
|
||||||
log.Error("failed to flush firewall rules: ", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
|
||||||
|
|
||||||
enableSSH := (networkMap.PeerConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig.SshEnabled)
|
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// if TCP protocol rules not squashed and SSH enabled
|
|
||||||
// we add default firewall rule which accepts connection to any peer
|
|
||||||
// in the network by SSH (TCP 22 port).
|
|
||||||
if enableSSH {
|
|
||||||
rules = append(rules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
|
||||||
// we have old version of management without rules handling, we should allow all traffic
|
|
||||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
|
||||||
log.Warn("this peer is connected to a NetBird Management service with an older version. Allowing all traffic from connected peers")
|
|
||||||
rules = append(rules,
|
|
||||||
&mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
&mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
applyFailed := false
|
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
|
||||||
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
|
||||||
|
|
||||||
// calculate which IP's can be grouped in by which ipset
|
|
||||||
// to do that we use rule selector (which is just rule properties without IP's)
|
|
||||||
for _, r := range rules {
|
|
||||||
selector := d.getRuleGroupingSelector(r)
|
|
||||||
ipset, ok := ipsetByRuleSelectors[selector]
|
|
||||||
if !ok {
|
|
||||||
ipset = &ipsetInfo{}
|
|
||||||
}
|
|
||||||
|
|
||||||
ipset.ipCount++
|
|
||||||
ipsetByRuleSelectors[selector] = ipset
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rules {
|
|
||||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
|
||||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
|
||||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
|
||||||
ipsetName := ""
|
|
||||||
if ipset.name == "" {
|
|
||||||
d.ipsetCounter++
|
|
||||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
|
||||||
}
|
|
||||||
ipsetName = ipset.name
|
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
|
||||||
applyFailed = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
newRulePairs[pairID] = rulePair
|
|
||||||
}
|
|
||||||
if applyFailed {
|
|
||||||
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
|
||||||
for _, rules := range newRulePairs {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := d.manager.DeleteRule(rule); err != nil {
|
|
||||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for pairID, rules := range d.rulesPairs {
|
|
||||||
if _, ok := newRulePairs[pairID]; !ok {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := d.manager.DeleteRule(rule); err != nil {
|
|
||||||
log.Errorf("failed to delete firewall rule: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(d.rulesPairs, pairID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
d.rulesPairs = newRulePairs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop ACL controller and clear firewall state
|
|
||||||
func (d *DefaultManager) Stop() {
|
|
||||||
d.mutex.Lock()
|
|
||||||
defer d.mutex.Unlock()
|
|
||||||
|
|
||||||
if err := d.manager.Reset(); err != nil {
|
|
||||||
log.WithError(err).Error("reset firewall state")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
|
||||||
r *mgmProto.FirewallRule,
|
|
||||||
ipsetName string,
|
|
||||||
) (string, []firewall.Rule, error) {
|
|
||||||
ip := net.ParseIP(r.PeerIP)
|
|
||||||
if ip == nil {
|
|
||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
protocol := convertToFirewallProtocol(r.Protocol)
|
|
||||||
if protocol == firewall.ProtocolUnknown {
|
|
||||||
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
|
||||||
}
|
|
||||||
|
|
||||||
action := convertFirewallAction(r.Action)
|
|
||||||
if action == firewall.ActionUnknown {
|
|
||||||
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
|
||||||
}
|
|
||||||
|
|
||||||
var port *firewall.Port
|
|
||||||
if r.Port != "" {
|
|
||||||
value, err := strconv.Atoi(r.Port)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
|
||||||
}
|
|
||||||
port = &firewall.Port{
|
|
||||||
Values: []int{value},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
|
||||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
|
||||||
return ruleID, rulesPair, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []firewall.Rule
|
|
||||||
var err error
|
|
||||||
switch r.Direction {
|
|
||||||
case mgmProto.FirewallRule_IN:
|
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
|
||||||
case mgmProto.FirewallRule_OUT:
|
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
|
||||||
default:
|
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.rulesPairs[ruleID] = rules
|
|
||||||
return ruleID, rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
|
||||||
ip net.IP,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
var rules []firewall.Rule
|
|
||||||
rule, err := d.manager.AddFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
rules = append(rules, rule)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(
|
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, rule), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
|
||||||
ip net.IP,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
var rules []firewall.Rule
|
|
||||||
rule, err := d.manager.AddFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
rules = append(rules, rule)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(
|
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, rule), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
|
||||||
func (d *DefaultManager) getRuleID(
|
|
||||||
ip net.IP,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
direction int,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
comment string,
|
|
||||||
) string {
|
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
|
||||||
if port != nil {
|
|
||||||
idStr += port.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
|
||||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
|
||||||
//
|
|
||||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
|
||||||
// but other has port definitions or has drop policy.
|
|
||||||
func (d *DefaultManager) squashAcceptRules(
|
|
||||||
networkMap *mgmProto.NetworkMap,
|
|
||||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
|
||||||
totalIPs := 0
|
|
||||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
|
||||||
for range p.AllowedIps {
|
|
||||||
totalIPs++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
|
|
||||||
|
|
||||||
in := protoMatch{}
|
|
||||||
out := protoMatch{}
|
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
|
||||||
// But we zeroed the IP's for protocol if:
|
|
||||||
// 1. Any of the rule has DROP action type.
|
|
||||||
// 2. Any of rule contains Port.
|
|
||||||
//
|
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
|
||||||
drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
|
|
||||||
if drop {
|
|
||||||
protocols[r.Protocol] = map[string]int{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
|
||||||
protocols[r.Protocol] = map[string]int{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// special case, when we recieve this all network IP address
|
|
||||||
// it means that rules for that protocol was already optimized on the
|
|
||||||
// management side
|
|
||||||
if r.PeerIP == "0.0.0.0" {
|
|
||||||
squashedRules = append(squashedRules, r)
|
|
||||||
squashedProtocols[r.Protocol] = struct{}{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ipset := protocols[r.Protocol]
|
|
||||||
|
|
||||||
if _, ok := ipset[r.PeerIP]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ipset[r.PeerIP] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
// calculate squash for different directions
|
|
||||||
if r.Direction == mgmProto.FirewallRule_IN {
|
|
||||||
addRuleToCalculationMap(i, r, in)
|
|
||||||
} else {
|
|
||||||
addRuleToCalculationMap(i, r, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// order of squashing by protocol is important
|
|
||||||
// only for ther first element ALL, it must be done first
|
|
||||||
protocolOrders := []mgmProto.FirewallRuleProtocol{
|
|
||||||
mgmProto.FirewallRule_ALL,
|
|
||||||
mgmProto.FirewallRule_ICMP,
|
|
||||||
mgmProto.FirewallRule_TCP,
|
|
||||||
mgmProto.FirewallRule_UDP,
|
|
||||||
}
|
|
||||||
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
|
||||||
for _, protocol := range protocolOrders {
|
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
|
||||||
// don't squash if :
|
|
||||||
// 1. Rules not cover all peers in the network
|
|
||||||
// 2. Rules cover only one peer in the network.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
|
||||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: direction,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: protocol,
|
|
||||||
})
|
|
||||||
squashedProtocols[protocol] = struct{}{}
|
|
||||||
|
|
||||||
if protocol == mgmProto.FirewallRule_ALL {
|
|
||||||
// if we have ALL traffic type squashed rule
|
|
||||||
// it allows all other type of traffic, so we can stop processing
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
squash(in, mgmProto.FirewallRule_IN)
|
|
||||||
squash(out, mgmProto.FirewallRule_OUT)
|
|
||||||
|
|
||||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
|
||||||
return squashedRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(squashedRules) == 0 {
|
|
||||||
return networkMap.FirewallRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []*mgmProto.FirewallRule
|
|
||||||
// filter out rules which was squashed from final list
|
|
||||||
// if we also have other not squashed rules.
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
|
||||||
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rules = append(rules, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, squashedRules...), squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
|
||||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
|
||||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
|
||||||
switch protocol {
|
|
||||||
case mgmProto.FirewallRule_TCP:
|
|
||||||
return firewall.ProtocolTCP
|
|
||||||
case mgmProto.FirewallRule_UDP:
|
|
||||||
return firewall.ProtocolUDP
|
|
||||||
case mgmProto.FirewallRule_ICMP:
|
|
||||||
return firewall.ProtocolICMP
|
|
||||||
case mgmProto.FirewallRule_ALL:
|
|
||||||
return firewall.ProtocolALL
|
|
||||||
default:
|
|
||||||
return firewall.ProtocolUnknown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) bool {
|
|
||||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action {
|
|
||||||
switch action {
|
|
||||||
case mgmProto.FirewallRule_ACCEPT:
|
|
||||||
return firewall.ActionAccept
|
|
||||||
case mgmProto.FirewallRule_DROP:
|
|
||||||
return firewall.ActionDrop
|
|
||||||
default:
|
|
||||||
return firewall.ActionUnknown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create creates a firewall manager instance
|
|
||||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|
||||||
if iface.IsUserspaceBind() {
|
|
||||||
// use userspace packet filtering firewall
|
|
||||||
fm, err := uspfilter.Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return newDefaultManager(fm), nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create creates a firewall manager instance for the Linux
|
|
||||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|
||||||
var fm firewall.Manager
|
|
||||||
if iface.IsUserspaceBind() {
|
|
||||||
// use userspace packet filtering firewall
|
|
||||||
if fm, err = uspfilter.Create(iface); err != nil {
|
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if fm, err = nftables.Create(iface); err != nil {
|
|
||||||
log.Debugf("failed to create nftables manager: %s", err)
|
|
||||||
// fallback to iptables
|
|
||||||
if fm, err = iptables.Create(iface); err != nil {
|
|
||||||
log.Errorf("failed to create iptables manager: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return newDefaultManager(fm), nil
|
|
||||||
}
|
|
||||||
@@ -1,333 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
Port: "80",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_DROP,
|
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
|
||||||
Port: "53",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
|
||||||
// iface.EXPECT().Name().Return("lo")
|
|
||||||
iface.EXPECT().SetFilter(gomock.Any())
|
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
acl, err := Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create ACL manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer acl.Stop()
|
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
|
|
||||||
if len(acl.rulesPairs) != 2 {
|
|
||||||
t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("add extra rules", func(t *testing.T) {
|
|
||||||
existedPairs := map[string]struct{}{}
|
|
||||||
for id := range acl.rulesPairs {
|
|
||||||
existedPairs[id] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove first rule
|
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[1:]
|
|
||||||
networkMap.FirewallRules = append(
|
|
||||||
networkMap.FirewallRules,
|
|
||||||
&mgmProto.FirewallRule{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_DROP,
|
|
||||||
Protocol: mgmProto.FirewallRule_ICMP,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
|
||||||
if len(acl.rulesPairs) != 2 {
|
|
||||||
t.Errorf("firewall rules not applied")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that old rule was removed
|
|
||||||
previousCount := 0
|
|
||||||
for id := range acl.rulesPairs {
|
|
||||||
if _, ok := existedPairs[id]; ok {
|
|
||||||
previousCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if previousCount != 1 {
|
|
||||||
t.Errorf("old rule was not removed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("handle default rules", func(t *testing.T) {
|
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
|
||||||
if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
|
|
||||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
if len(acl.rulesPairs) != 2 {
|
|
||||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
if len(rules) != 2 {
|
|
||||||
t.Errorf("rules should contain 2, got: %v", rules)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r := rules[0]
|
|
||||||
if r.PeerIP != "0.0.0.0" {
|
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
|
||||||
return
|
|
||||||
} else if r.Direction != mgmProto.FirewallRule_IN {
|
|
||||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
} else if r.Protocol != mgmProto.FirewallRule_ALL {
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r = rules[1]
|
|
||||||
if r.PeerIP != "0.0.0.0" {
|
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
|
||||||
return
|
|
||||||
} else if r.Direction != mgmProto.FirewallRule_OUT {
|
|
||||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
} else if r.Protocol != mgmProto.FirewallRule_ALL {
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
|
||||||
t.Errorf("we should got same amount of rules as intput, got %v", len(rules))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
PeerConfig: &mgmProto.PeerConfig{
|
|
||||||
SshConfig: &mgmProto.SSHConfig{
|
|
||||||
SshEnabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
|
||||||
// iface.EXPECT().Name().Return("lo")
|
|
||||||
iface.EXPECT().SetFilter(gomock.Any())
|
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
acl, err := Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create ACL manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer acl.Stop()
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
|
|
||||||
if len(acl.rulesPairs) != 4 {
|
|
||||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
## Mocks
|
|
||||||
|
|
||||||
To generate (or refresh) mocks from acl package please install [mockgen](https://github.com/golang/mock).
|
|
||||||
Run this command from the `./client/internal/acl` folder to update iface mapper interface mock:
|
|
||||||
```bash
|
|
||||||
mockgen -destination mocks/iface_mapper.go -package mocks . IFaceMapper
|
|
||||||
```
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/netbirdio/netbird/client/internal/acl (interfaces: IFaceMapper)
|
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
|
||||||
package mocks
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
iface "github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
|
||||||
type MockIFaceMapper struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockIFaceMapperMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockIFaceMapperMockRecorder is the mock recorder for MockIFaceMapper.
|
|
||||||
type MockIFaceMapperMockRecorder struct {
|
|
||||||
mock *MockIFaceMapper
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockIFaceMapper creates a new mock instance.
|
|
||||||
func NewMockIFaceMapper(ctrl *gomock.Controller) *MockIFaceMapper {
|
|
||||||
mock := &MockIFaceMapper{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockIFaceMapperMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|
||||||
func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address mocks base method.
|
|
||||||
func (m *MockIFaceMapper) Address() iface.WGAddress {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "Address")
|
|
||||||
ret0, _ := ret[0].(iface.WGAddress)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address indicates an expected call of Address.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) Address() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Address", reflect.TypeOf((*MockIFaceMapper)(nil).Address))
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUserspaceBind mocks base method.
|
|
||||||
func (m *MockIFaceMapper) IsUserspaceBind() bool {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "IsUserspaceBind")
|
|
||||||
ret0, _ := ret[0].(bool)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUserspaceBind indicates an expected call of IsUserspaceBind.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) IsUserspaceBind() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsUserspaceBind", reflect.TypeOf((*MockIFaceMapper)(nil).IsUserspaceBind))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name mocks base method.
|
|
||||||
func (m *MockIFaceMapper) Name() string {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "Name")
|
|
||||||
ret0, _ := ret[0].(string)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name indicates an expected call of Name.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockIFaceMapper)(nil).Name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFilter mocks base method.
|
|
||||||
func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "SetFilter", arg0)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFilter indicates an expected call of SetFilter.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
|
||||||
}
|
|
||||||
@@ -1,202 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HostedGrantType grant type for device flow on Hosted
|
|
||||||
const (
|
|
||||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
|
||||||
|
|
||||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
|
||||||
// for the Device Authorization Flow.
|
|
||||||
type DeviceAuthorizationFlow struct {
|
|
||||||
providerConfig internal.DeviceAuthProviderConfig
|
|
||||||
|
|
||||||
HTTPClient HTTPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
|
||||||
type RequestDeviceCodePayload struct {
|
|
||||||
Audience string `json:"audience"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
Scope string `json:"scope"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenRequestPayload used for requesting the auth0 token
|
|
||||||
type TokenRequestPayload struct {
|
|
||||||
GrantType string `json:"grant_type"`
|
|
||||||
DeviceCode string `json:"device_code,omitempty"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenRequestResponse used for parsing Hosted token's response
|
|
||||||
type TokenRequestResponse struct {
|
|
||||||
Error string `json:"error"`
|
|
||||||
ErrorDescription string `json:"error_description"`
|
|
||||||
TokenInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
|
||||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DeviceAuthorizationFlow{
|
|
||||||
providerConfig: config,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientID returns the provider client id
|
|
||||||
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
|
||||||
return d.providerConfig.ClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
|
||||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("client_id", d.providerConfig.ClientID)
|
|
||||||
form.Add("audience", d.providerConfig.Audience)
|
|
||||||
form.Add("scope", d.providerConfig.Scope)
|
|
||||||
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
|
|
||||||
strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
res, err := d.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer res.Body.Close()
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceCode := AuthFlowInfo{}
|
|
||||||
err = json.Unmarshal(body, &deviceCode)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
|
||||||
if deviceCode.VerificationURIComplete == "" {
|
|
||||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceCode, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("client_id", d.providerConfig.ClientID)
|
|
||||||
form.Add("grant_type", HostedGrantType)
|
|
||||||
form.Add("device_code", info.DeviceCode)
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
res, err := d.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode > 499 {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenResponse := TokenRequestResponse{}
|
|
||||||
err = json.Unmarshal(body, &tokenResponse)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenResponse, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
|
||||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
|
||||||
interval := time.Duration(info.Interval) * time.Second
|
|
||||||
ticker := time.NewTicker(interval)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return TokenInfo{}, ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
|
|
||||||
tokenResponse, err := d.requestToken(info)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tokenResponse.Error != "" {
|
|
||||||
if tokenResponse.Error == "authorization_pending" {
|
|
||||||
continue
|
|
||||||
} else if tokenResponse.Error == "slow_down" {
|
|
||||||
interval = interval + (3 * time.Second)
|
|
||||||
ticker.Reset(interval)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenInfo := TokenInfo{
|
|
||||||
AccessToken: tokenResponse.AccessToken,
|
|
||||||
TokenType: tokenResponse.TokenType,
|
|
||||||
RefreshToken: tokenResponse.RefreshToken,
|
|
||||||
IDToken: tokenResponse.IDToken,
|
|
||||||
ExpiresIn: tokenResponse.ExpiresIn,
|
|
||||||
UseIDToken: d.providerConfig.UseIDToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenInfo, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
|
|
||||||
type OAuthFlow interface {
|
|
||||||
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
|
|
||||||
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
|
|
||||||
GetClientID(ctx context.Context) string
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPClient http client interface for API calls
|
|
||||||
type HTTPClient interface {
|
|
||||||
Do(req *http.Request) (*http.Response, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
|
||||||
type AuthFlowInfo struct {
|
|
||||||
DeviceCode string `json:"device_code"`
|
|
||||||
UserCode string `json:"user_code"`
|
|
||||||
VerificationURI string `json:"verification_uri"`
|
|
||||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
Interval int `json:"interval"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Claims used when validating the access token
|
|
||||||
type Claims struct {
|
|
||||||
Audience interface{} `json:"aud"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenInfo holds information of issued access token
|
|
||||||
type TokenInfo struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
IDToken string `json:"id_token"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
UseIDToken bool `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
|
||||||
func (t TokenInfo) GetTokenToUse() string {
|
|
||||||
if t.UseIDToken {
|
|
||||||
return t.IDToken
|
|
||||||
}
|
|
||||||
return t.AccessToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
|
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
|
||||||
log.Debug("getting device authorization flow info")
|
|
||||||
|
|
||||||
// Try to initialize the Device Authorization Flow
|
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err == nil {
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("getting device authorization flow info failed with error: %v", err)
|
|
||||||
log.Debugf("falling back to pkce authorization flow info")
|
|
||||||
|
|
||||||
// If Device Authorization Flow failed, try the PKCE Authorization Flow
|
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if ok && s.Code() == codes.NotFound {
|
|
||||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
|
||||||
"If you are using hosting Netbird see documentation at " +
|
|
||||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
|
||||||
} else if ok && s.Code() == codes.Unimplemented {
|
|
||||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
|
||||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
|
||||||
}
|
|
||||||
@@ -1,238 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"html/template"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
|
||||||
|
|
||||||
const (
|
|
||||||
queryState = "state"
|
|
||||||
queryCode = "code"
|
|
||||||
defaultPKCETimeoutSeconds = 300
|
|
||||||
)
|
|
||||||
|
|
||||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
|
||||||
// the Authorization Code Flow with PKCE.
|
|
||||||
type PKCEAuthorizationFlow struct {
|
|
||||||
providerConfig internal.PKCEAuthProviderConfig
|
|
||||||
state string
|
|
||||||
codeVerifier string
|
|
||||||
oAuthConfig *oauth2.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
|
||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
|
||||||
var availableRedirectURL string
|
|
||||||
|
|
||||||
// find the first available redirect URL
|
|
||||||
for _, redirectURL := range config.RedirectURLs {
|
|
||||||
if !isRedirectURLPortUsed(redirectURL) {
|
|
||||||
availableRedirectURL = redirectURL
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if availableRedirectURL == "" {
|
|
||||||
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := &oauth2.Config{
|
|
||||||
ClientID: config.ClientID,
|
|
||||||
ClientSecret: config.ClientSecret,
|
|
||||||
Endpoint: oauth2.Endpoint{
|
|
||||||
AuthURL: config.AuthorizationEndpoint,
|
|
||||||
TokenURL: config.TokenEndpoint,
|
|
||||||
},
|
|
||||||
RedirectURL: availableRedirectURL,
|
|
||||||
Scopes: strings.Split(config.Scope, " "),
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PKCEAuthorizationFlow{
|
|
||||||
providerConfig: config,
|
|
||||||
oAuthConfig: cfg,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientID returns the provider client id
|
|
||||||
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
|
||||||
return p.providerConfig.ClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestAuthInfo requests a authorization code login flow information.
|
|
||||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
|
|
||||||
state, err := randomBytesInHex(24)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
|
||||||
}
|
|
||||||
p.state = state
|
|
||||||
|
|
||||||
codeVerifier, err := randomBytesInHex(64)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
|
|
||||||
}
|
|
||||||
p.codeVerifier = codeVerifier
|
|
||||||
|
|
||||||
codeChallenge := createCodeChallenge(codeVerifier)
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(
|
|
||||||
state,
|
|
||||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
||||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
|
||||||
)
|
|
||||||
|
|
||||||
return AuthFlowInfo{
|
|
||||||
VerificationURIComplete: authURL,
|
|
||||||
ExpiresIn: defaultPKCETimeoutSeconds,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
|
||||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
|
||||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
|
||||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
|
||||||
tokenChan := make(chan *oauth2.Token, 1)
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
|
|
||||||
go p.startServer(tokenChan, errChan)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return TokenInfo{}, ctx.Err()
|
|
||||||
case token := <-tokenChan:
|
|
||||||
return p.handleOAuthToken(token)
|
|
||||||
case err := <-errChan:
|
|
||||||
return TokenInfo{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
|
||||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
port := parsedURL.Port()
|
|
||||||
|
|
||||||
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
|
|
||||||
defer func() {
|
|
||||||
if err := server.Shutdown(context.Background()); err != nil {
|
|
||||||
log.Errorf("error while shutting down pkce flow server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
tokenValidatorFunc := func() (*oauth2.Token, error) {
|
|
||||||
query := req.URL.Query()
|
|
||||||
|
|
||||||
state := query.Get(queryState)
|
|
||||||
// Prevent timing attacks on state
|
|
||||||
if subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
|
||||||
return nil, fmt.Errorf("invalid state")
|
|
||||||
}
|
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
|
||||||
if code == "" {
|
|
||||||
return nil, fmt.Errorf("missing code")
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
|
||||||
req.Context(),
|
|
||||||
code,
|
|
||||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := tokenValidatorFunc()
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
|
||||||
renderPKCEFlowTmpl(w, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenChan <- token
|
|
||||||
renderPKCEFlowTmpl(w, nil)
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := server.ListenAndServe(); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
|
||||||
tokenInfo := TokenInfo{
|
|
||||||
AccessToken: token.AccessToken,
|
|
||||||
RefreshToken: token.RefreshToken,
|
|
||||||
TokenType: token.TokenType,
|
|
||||||
ExpiresIn: token.Expiry.Second(),
|
|
||||||
UseIDToken: p.providerConfig.UseIDToken,
|
|
||||||
}
|
|
||||||
if idToken, ok := token.Extra("id_token").(string); ok {
|
|
||||||
tokenInfo.IDToken = idToken
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createCodeChallenge(codeVerifier string) string {
|
|
||||||
sha2 := sha256.Sum256([]byte(codeVerifier))
|
|
||||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
|
||||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
|
||||||
parsedURL, err := url.Parse(redirectURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to parse redirect URL: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
|
||||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("error while closing the connection: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
|
||||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data := make(map[string]string)
|
|
||||||
if authError != nil {
|
|
||||||
data["Error"] = authError.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tmpl.Execute(w, data); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func randomBytesInHex(count int) (string, error) {
|
|
||||||
buf := make([]byte, count)
|
|
||||||
_, err := io.ReadFull(rand.Reader, buf)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hex.EncodeToString(buf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isValidAccessToken is a simple validation of the access token
|
|
||||||
func isValidAccessToken(token string, audience string) error {
|
|
||||||
if token == "" {
|
|
||||||
return fmt.Errorf("token received is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
encodedClaims := strings.Split(token, ".")[1]
|
|
||||||
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := Claims{}
|
|
||||||
err = json.Unmarshal(claimsString, &claims)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims.Audience == nil {
|
|
||||||
return fmt.Errorf("required token field audience is absent")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Audience claim of JWT can be a string or an array of strings
|
|
||||||
typ := reflect.TypeOf(claims.Audience)
|
|
||||||
switch typ.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
if claims.Audience == audience {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case reflect.Slice:
|
|
||||||
for _, aud := range claims.Audience.([]interface{}) {
|
|
||||||
if audience == aud {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("invalid JWT token audience field")
|
|
||||||
}
|
|
||||||
@@ -215,13 +215,11 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||||
if *input.PreSharedKey != "" {
|
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
||||||
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
|
||||||
*input.PreSharedKey, config.PreSharedKey)
|
*input.PreSharedKey, config.PreSharedKey)
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
refresh = true
|
refresh = true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if config.SSHKey == "" {
|
if config.SSHKey == "" {
|
||||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
|
|||||||
@@ -63,22 +63,7 @@ func TestGetConfig(t *testing.T) {
|
|||||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||||
|
|
||||||
// case 4: new empty pre-shared key config -> fetch it
|
// case 4: existing config, but new managementURL has been provided -> update config
|
||||||
newPreSharedKey := ""
|
|
||||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
|
||||||
ManagementURL: managementURL,
|
|
||||||
AdminURL: adminURL,
|
|
||||||
ConfigPath: path,
|
|
||||||
PreSharedKey: &newPreSharedKey,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
|
||||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
|
||||||
|
|
||||||
// case 5: existing config, but new managementURL has been provided -> update config
|
|
||||||
newManagementURL := "https://test.newManagement.url:33071"
|
newManagementURL := "https://test.newManagement.url:33071"
|
||||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
ManagementURL: newManagementURL,
|
ManagementURL: newManagementURL,
|
||||||
|
|||||||
@@ -12,9 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -25,24 +23,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
|
||||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunClientMobile with main logic on mobile system
|
|
||||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
|
||||||
// in case of non Android os these variables will be nil
|
|
||||||
mobileDependency := MobileDependency{
|
|
||||||
TunAdapter: tunAdapter,
|
|
||||||
IFaceDiscover: iFaceDiscover,
|
|
||||||
RouteListener: routeListener,
|
|
||||||
HostDNSAddresses: dnsAddresses,
|
|
||||||
DnsReadyListener: dnsReadyListener,
|
|
||||||
}
|
|
||||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -169,7 +150,13 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return wrapErr(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ import (
|
|||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||||
type DeviceAuthorizationFlow struct {
|
type DeviceAuthorizationFlow struct {
|
||||||
Provider string
|
Provider string
|
||||||
ProviderConfig DeviceAuthProviderConfig
|
ProviderConfig ProviderConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
type DeviceAuthProviderConfig struct {
|
type ProviderConfig struct {
|
||||||
// ClientID An IDP application client id
|
// ClientID An IDP application client id
|
||||||
ClientID string
|
ClientID string
|
||||||
// ClientSecret An IDP application client secret
|
// ClientSecret An IDP application client secret
|
||||||
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||||
|
|
||||||
ProviderConfig: DeviceAuthProviderConfig{
|
ProviderConfig: ProviderConfig{
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||||
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||||
}
|
}
|
||||||
|
|
||||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DeviceAuthorizationFlow{}, err
|
return DeviceAuthorizationFlow{}, err
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
return deviceAuthorizationFlow, nil
|
return deviceAuthorizationFlow, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
func isProviderConfigValid(config ProviderConfig) error {
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
if config.Audience == "" {
|
if config.Audience == "" {
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -34,10 +32,6 @@ func newFileConfigurator() (hostManager, error) {
|
|||||||
return &fileConfigurator{}, nil
|
return &fileConfigurator{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) supportCustomPort() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
backupFileExist := false
|
backupFileExist := false
|
||||||
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config hostDNSConfig) error
|
applyDNSConfig(config hostDNSConfig) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
supportCustomPort() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type hostDNSConfig struct {
|
type hostDNSConfig struct {
|
||||||
@@ -29,7 +28,6 @@ type domainConfig struct {
|
|||||||
type mockHostConfigurator struct {
|
type mockHostConfigurator struct {
|
||||||
applyDNSConfigFunc func(config hostDNSConfig) error
|
applyDNSConfigFunc func(config hostDNSConfig) error
|
||||||
restoreHostDNSFunc func() error
|
restoreHostDNSFunc func() error
|
||||||
supportCustomPortFunc func() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
@@ -46,18 +44,10 @@ func (m *mockHostConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("method restoreHostDNS is not implemented")
|
return fmt.Errorf("method restoreHostDNS is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) supportCustomPort() bool {
|
|
||||||
if m.supportCustomPortFunc != nil {
|
|
||||||
return m.supportCustomPortFunc()
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNoopHostMocker() hostManager {
|
func newNoopHostMocker() hostManager {
|
||||||
return &mockHostConfigurator{
|
return &mockHostConfigurator{
|
||||||
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil },
|
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil },
|
||||||
restoreHostDNSFunc: func() error { return nil },
|
restoreHostDNSFunc: func() error { return nil },
|
||||||
supportCustomPortFunc: func() bool { return true },
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
type androidHostManager struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|
||||||
return &androidHostManager{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a androidHostManager) restoreHostDNS() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a androidHostManager) supportCustomPort() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,16 +33,12 @@ type systemConfigurator struct {
|
|||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(_ WGIface) (hostManager, error) {
|
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) supportCustomPort() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,7 +23,7 @@ const (
|
|||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
)
|
)
|
||||||
@@ -31,7 +32,7 @@ type registryConfigurator struct {
|
|||||||
existingSearchDomains []string
|
existingSearchDomains []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -41,10 +42,6 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *registryConfigurator) supportCustomPort() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
var err error
|
var err error
|
||||||
if config.routeAll {
|
if config.routeAll {
|
||||||
|
|||||||
@@ -2,12 +2,10 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type registrationMap map[string]struct{}
|
type registrationMap map[string]struct{}
|
||||||
@@ -17,12 +15,9 @@ type localResolver struct {
|
|||||||
records sync.Map
|
records sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *localResolver) stop() {
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
log.Tracef("received question: %#v", r.Question[0])
|
log.Tracef("received question: %#v\n", r.Question[0])
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
replyMessage.SetReply(r)
|
replyMessage.SetReply(r)
|
||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
|||||||
@@ -7,17 +7,16 @@ import (
|
|||||||
|
|
||||||
// MockServer is the mock instance of a dns server
|
// MockServer is the mock instance of a dns server
|
||||||
type MockServer struct {
|
type MockServer struct {
|
||||||
InitializeFunc func() error
|
StartFunc func()
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize mock implementation of Initialize from Server interface
|
// Start mock implementation of Start from Server interface
|
||||||
func (m *MockServer) Initialize() error {
|
func (m *MockServer) Start() {
|
||||||
if m.InitializeFunc != nil {
|
if m.StartFunc != nil {
|
||||||
return m.InitializeFunc()
|
m.StartFunc()
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop mock implementation of Stop from Server interface
|
// Stop mock implementation of Stop from Server interface
|
||||||
@@ -27,15 +26,6 @@ func (m *MockServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) DnsIP() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
if m.UpdateDNSServerFunc != nil {
|
if m.UpdateDNSServerFunc != nil {
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -13,6 +11,7 @@ import (
|
|||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,7 +69,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -89,10 +88,6 @@ func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,6 +5,7 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,16 +15,12 @@ type resolvconf struct {
|
|||||||
ifaceName string
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface.Name(),
|
ifaceName: wgInterface.Name(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) supportCustomPort() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||||
var err error
|
var err error
|
||||||
if !config.routeAll {
|
if !config.routeAll {
|
||||||
|
|||||||
@@ -1,103 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
type responseWriter struct {
|
|
||||||
local net.Addr
|
|
||||||
remote net.Addr
|
|
||||||
packet gopacket.Packet
|
|
||||||
device tun.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAddr returns the net.Addr of the server
|
|
||||||
func (r *responseWriter) LocalAddr() net.Addr {
|
|
||||||
return r.local
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteAddr returns the net.Addr of the client that sent the current request.
|
|
||||||
func (r *responseWriter) RemoteAddr() net.Addr {
|
|
||||||
return r.remote
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteMsg writes a reply back to the client.
|
|
||||||
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
|
||||||
buff, err := msg.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = r.Write(buff)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes a raw buffer back to the client.
|
|
||||||
func (r *responseWriter) Write(data []byte) (int, error) {
|
|
||||||
var ip gopacket.SerializableLayer
|
|
||||||
|
|
||||||
// Get the UDP layer
|
|
||||||
udpLayer := r.packet.Layer(layers.LayerTypeUDP)
|
|
||||||
udp := udpLayer.(*layers.UDP)
|
|
||||||
|
|
||||||
// Swap the source and destination addresses for the response
|
|
||||||
udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort
|
|
||||||
|
|
||||||
// Check if it's an IPv4 packet
|
|
||||||
if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil {
|
|
||||||
ipv4 := ipv4Layer.(*layers.IPv4)
|
|
||||||
ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP
|
|
||||||
ip = ipv4
|
|
||||||
} else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil {
|
|
||||||
ipv6 := ipv6Layer.(*layers.IPv6)
|
|
||||||
ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP
|
|
||||||
ip = ipv6
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to set network layer for checksum: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize the packet
|
|
||||||
buffer := gopacket.NewSerializeBuffer()
|
|
||||||
options := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := gopacket.Payload(data)
|
|
||||||
err := gopacket.SerializeLayers(buffer, options, ip, udp, payload)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to serialize packet: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
send := buffer.Bytes()
|
|
||||||
sendBuffer := make([]byte, 40, len(send)+40)
|
|
||||||
sendBuffer = append(sendBuffer, send...)
|
|
||||||
|
|
||||||
return r.device.Write([][]byte{sendBuffer}, 40)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the connection.
|
|
||||||
func (r *responseWriter) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TsigStatus returns the status of the Tsig.
|
|
||||||
func (r *responseWriter) TsigStatus() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TsigTimersOnly sets the tsig timers only boolean.
|
|
||||||
func (r *responseWriter) TsigTimersOnly(bool) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack lets the caller take over the connection.
|
|
||||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
|
||||||
func (r *responseWriter) Hijack() {
|
|
||||||
}
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface/mocks"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestResponseWriterLocalAddr(t *testing.T) {
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
device := mocks.NewMockDevice(ctrl)
|
|
||||||
device.EXPECT().Write(gomock.Any(), gomock.Any())
|
|
||||||
|
|
||||||
request := &dns.Msg{
|
|
||||||
Question: []dns.Question{{
|
|
||||||
Name: "google.com.",
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.TypeA,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
|
||||||
replyMessage.SetReply(request)
|
|
||||||
replyMessage.RecursionAvailable = true
|
|
||||||
replyMessage.Rcode = dns.RcodeSuccess
|
|
||||||
replyMessage.Answer = []dns.RR{
|
|
||||||
&dns.A{
|
|
||||||
A: net.IPv4(8, 8, 8, 8),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4 := &layers.IPv4{
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
SrcIP: net.IPv4(127, 0, 0, 1),
|
|
||||||
DstIP: net.IPv4(127, 0, 0, 2),
|
|
||||||
}
|
|
||||||
udp := &layers.UDP{
|
|
||||||
DstPort: 53,
|
|
||||||
SrcPort: 45223,
|
|
||||||
}
|
|
||||||
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
|
|
||||||
t.Error("failed to set network layer for checksum")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize the packet
|
|
||||||
buffer := gopacket.NewSerializeBuffer()
|
|
||||||
options := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
requestData, err := request.Pack()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("got an error while packing the request message, error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload := gopacket.Payload(requestData)
|
|
||||||
|
|
||||||
if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil {
|
|
||||||
t.Errorf("failed to serialize packet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rw := &responseWriter{
|
|
||||||
local: &net.UDPAddr{
|
|
||||||
IP: net.IPv4(127, 0, 0, 1),
|
|
||||||
Port: 55223,
|
|
||||||
},
|
|
||||||
remote: &net.UDPAddr{
|
|
||||||
IP: net.IPv4(127, 0, 0, 1),
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
packet: gopacket.NewPacket(
|
|
||||||
buffer.Bytes(),
|
|
||||||
layers.LayerTypeIPv4,
|
|
||||||
gopacket.Default,
|
|
||||||
),
|
|
||||||
device: device,
|
|
||||||
}
|
|
||||||
if err := rw.WriteMsg(replyMessage); err != nil {
|
|
||||||
t.Errorf("got an error while writing the local resolver response, error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,469 +1,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
|
||||||
type ReadyListener interface {
|
|
||||||
OnReady()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Initialize() error
|
Start()
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
OnUpdatedHostDNSServer(strings []string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
|
||||||
|
|
||||||
// DefaultServer dns server object
|
|
||||||
type DefaultServer struct {
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
service service
|
|
||||||
dnsMuxMap registeredHandlerMap
|
|
||||||
localResolver *localResolver
|
|
||||||
wgInterface WGIface
|
|
||||||
hostManager hostManager
|
|
||||||
updateSerial uint64
|
|
||||||
previousConfigHash uint64
|
|
||||||
currentConfig hostDNSConfig
|
|
||||||
|
|
||||||
// permanent related properties
|
|
||||||
permanent bool
|
|
||||||
hostsDnsList []string
|
|
||||||
hostsDnsListLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
type handlerWithStop interface {
|
|
||||||
dns.Handler
|
|
||||||
stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
type muxUpdate struct {
|
|
||||||
domain string
|
|
||||||
handler handlerWithStop
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
|
||||||
var addrPort *netip.AddrPort
|
|
||||||
if customAddress != "" {
|
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
|
||||||
}
|
|
||||||
addrPort = &parsedAddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
var dnsService service
|
|
||||||
if wgInterface.IsUserspaceBind() {
|
|
||||||
dnsService = newServiceViaMemory(wgInterface)
|
|
||||||
} else {
|
|
||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
|
||||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
|
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
|
||||||
ds.permanent = true
|
|
||||||
ds.hostsDnsList = hostsDnsList
|
|
||||||
ds.addHostRootZone()
|
|
||||||
setServerDns(ds)
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
|
||||||
defaultServer := &DefaultServer{
|
|
||||||
ctx: ctx,
|
|
||||||
ctxCancel: stop,
|
|
||||||
service: dnsService,
|
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
localResolver: &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultServer
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize instantiate host manager and the dns service
|
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if s.hostManager != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.permanent {
|
|
||||||
err = s.service.Listen()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.hostManager, err = newHostManager(s.wgInterface)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// DnsIP returns the DNS resolver server IP address
|
|
||||||
//
|
|
||||||
// When kernel space interface used it return real DNS server listener IP address
|
|
||||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
|
||||||
func (s *DefaultServer) DnsIP() string {
|
|
||||||
return s.service.RuntimeIP()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
|
||||||
func (s *DefaultServer) Stop() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
s.ctxCancel()
|
|
||||||
|
|
||||||
if s.hostManager != nil {
|
|
||||||
err := s.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.service.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
|
||||||
s.hostsDnsListLock.Lock()
|
|
||||||
defer s.hostsDnsListLock.Unlock()
|
|
||||||
|
|
||||||
s.hostsDnsList = hostsDnsList
|
|
||||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
|
||||||
if ok {
|
|
||||||
log.Debugf("on new host DNS config but skip to apply it")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
|
||||||
s.addHostRootZone()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
log.Infof("not updating DNS server as context is closed")
|
|
||||||
return s.ctx.Err()
|
|
||||||
default:
|
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if s.hostManager == nil {
|
|
||||||
return fmt.Errorf("dns service is not initialized yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
|
||||||
ZeroNil: true,
|
|
||||||
IgnoreZeroValue: true,
|
|
||||||
SlicesAsSets: true,
|
|
||||||
UseStringer: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.previousConfigHash == hash {
|
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
|
||||||
s.updateSerial = serial
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateSerial = serial
|
|
||||||
s.previousConfigHash = hash
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|
||||||
// is the service should be disabled, we stop the listener or fake resolver
|
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
|
||||||
if update.ServiceEnable {
|
|
||||||
_ = s.service.Listen()
|
|
||||||
} else if !s.permanent {
|
|
||||||
s.service.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
|
||||||
s.updateLocalResolver(localRecords)
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
|
||||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
|
||||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
|
||||||
hostUpdate.routeAll = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: customZone.Domain,
|
|
||||||
handler: s.localResolver,
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
|
||||||
var class uint16 = dns.ClassINET
|
|
||||||
if record.Class != nbdns.DefaultClass {
|
|
||||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
|
||||||
}
|
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
|
||||||
localRecords[key] = record
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, localRecords, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
|
||||||
if len(nsGroup.NameServers) == 0 {
|
|
||||||
log.Warn("received a nameserver group with empty nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := newUpstreamResolver(s.ctx)
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
|
||||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
|
||||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// when upstream fails to resolve domain several times over all it servers
|
|
||||||
// it will calls this hook to exclude self from the configuration and
|
|
||||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
|
||||||
// because it is temporal deactivation until next try
|
|
||||||
//
|
|
||||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
|
||||||
// everything we need here is just to re-apply current configuration because it already
|
|
||||||
// contains this upstream settings (temporal deactivation not removed it)
|
|
||||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
|
||||||
|
|
||||||
var isContainRootUpdate bool
|
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
|
||||||
s.service.RegisterMux(update.domain, update.handler)
|
|
||||||
muxUpdateMap[update.domain] = update.handler
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
|
||||||
existingHandler.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if update.domain == nbdns.RootZone {
|
|
||||||
isContainRootUpdate = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
|
||||||
_, found := muxUpdateMap[key]
|
|
||||||
if !found {
|
|
||||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
|
||||||
s.hostsDnsListLock.Lock()
|
|
||||||
s.addHostRootZone()
|
|
||||||
s.hostsDnsListLock.Unlock()
|
|
||||||
existingHandler.stop()
|
|
||||||
} else {
|
|
||||||
existingHandler.stop()
|
|
||||||
s.service.DeregisterMux(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
|
||||||
for key := range s.localResolver.registeredMap {
|
|
||||||
_, found := update[key]
|
|
||||||
if !found {
|
|
||||||
s.localResolver.deleteRecord(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap := make(registrationMap)
|
|
||||||
for key, record := range update {
|
|
||||||
err := s.localResolver.registerRecord(record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
|
||||||
}
|
|
||||||
updatedMap[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localResolver.registeredMap = updatedMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
|
||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
|
||||||
func (s *DefaultServer) upstreamCallbacks(
|
|
||||||
nsGroup *nbdns.NameServerGroup,
|
|
||||||
handler dns.Handler,
|
|
||||||
) (deactivate func(), reactivate func()) {
|
|
||||||
var removeIndex map[string]int
|
|
||||||
deactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Info("temporary deactivate nameservers group due timeout")
|
|
||||||
|
|
||||||
removeIndex = make(map[string]int)
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
removeIndex[domain] = -1
|
|
||||||
}
|
|
||||||
if nsGroup.Primary {
|
|
||||||
removeIndex[nbdns.RootZone] = -1
|
|
||||||
s.currentConfig.routeAll = false
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, item := range s.currentConfig.domains {
|
|
||||||
if _, found := removeIndex[item.domain]; found {
|
|
||||||
s.currentConfig.domains[i].disabled = true
|
|
||||||
s.service.DeregisterMux(item.domain)
|
|
||||||
removeIndex[item.domain] = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
reactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
for domain, i := range removeIndex {
|
|
||||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.currentConfig.domains[i].disabled = false
|
|
||||||
s.service.RegisterMux(domain, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Debug("reactivate temporary disabled nameserver group")
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
s.currentConfig.routeAll = true
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) addHostRootZone() {
|
|
||||||
handler := newUpstreamResolver(s.ctx)
|
|
||||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
|
||||||
for n, ua := range s.hostsDnsList {
|
|
||||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
|
|
||||||
}
|
|
||||||
handler.deactivate = func() {}
|
|
||||||
handler.reactivate = func() {}
|
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
|
||||||
}
|
}
|
||||||
|
|||||||
32
client/internal/dns/server_android.go
Normal file
32
client/internal/dns/server_android.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultServer dummy dns server
|
||||||
|
type DefaultServer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer On Android the DNS feature is not supported yet
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
|
return &DefaultServer{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start dummy implementation
|
||||||
|
func (s DefaultServer) Start() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop dummy implementation
|
||||||
|
func (s DefaultServer) Stop() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer dummy implementation
|
||||||
|
func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
mutex sync.Mutex
|
|
||||||
server Server
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
|
|
||||||
func GetServerDns() (Server, error) {
|
|
||||||
mutex.Lock()
|
|
||||||
if server == nil {
|
|
||||||
mutex.Unlock()
|
|
||||||
return nil, fmt.Errorf("DNS server not instantiated yet")
|
|
||||||
}
|
|
||||||
s := server
|
|
||||||
mutex.Unlock()
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setServerDns(newServerServer Server) {
|
|
||||||
mutex.Lock()
|
|
||||||
server = newServerServer
|
|
||||||
defer mutex.Unlock()
|
|
||||||
}
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetServerDns(t *testing.T) {
|
|
||||||
_, err := GetServerDns()
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("invalid dns server instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &MockServer{}
|
|
||||||
setServerDns(srv)
|
|
||||||
|
|
||||||
srvB, err := GetServerDns()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("invalid dns server instance: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if srvB != srv {
|
|
||||||
t.Errorf("missmatch dns instances")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
465
client/internal/dns/server_nonandroid.go
Normal file
465
client/internal/dns/server_nonandroid.go
Normal file
@@ -0,0 +1,465 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 53
|
||||||
|
customPort = 5053
|
||||||
|
defaultIP = "127.0.0.1"
|
||||||
|
customIP = "127.0.0.153"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultServer dns server object
|
||||||
|
type DefaultServer struct {
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
upstreamCtxCancel context.CancelFunc
|
||||||
|
mux sync.Mutex
|
||||||
|
server *dns.Server
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
dnsMuxMap registrationMap
|
||||||
|
localResolver *localResolver
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
hostManager hostManager
|
||||||
|
updateSerial uint64
|
||||||
|
listenerIsRunning bool
|
||||||
|
runtimePort int
|
||||||
|
runtimeIP string
|
||||||
|
previousConfigHash uint64
|
||||||
|
currentConfig hostDNSConfig
|
||||||
|
customAddress *netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
type muxUpdate struct {
|
||||||
|
domain string
|
||||||
|
handler dns.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer returns a new dns server
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
dnsServer := &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
var addrPort *netip.AddrPort
|
||||||
|
if customAddress != "" {
|
||||||
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
|
if err != nil {
|
||||||
|
stop()
|
||||||
|
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||||
|
}
|
||||||
|
addrPort = &parsedAddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultServer := &DefaultServer{
|
||||||
|
ctx: ctx,
|
||||||
|
ctxCancel: stop,
|
||||||
|
server: dnsServer,
|
||||||
|
dnsMux: mux,
|
||||||
|
dnsMuxMap: make(registrationMap),
|
||||||
|
localResolver: &localResolver{
|
||||||
|
registeredMap: make(registrationMap),
|
||||||
|
},
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
runtimePort: defaultPort,
|
||||||
|
customAddress: addrPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
hostmanager, err := newHostManager(wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
stop()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defaultServer.hostManager = hostmanager
|
||||||
|
return defaultServer, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start runs the listener in a go routine
|
||||||
|
func (s *DefaultServer) Start() {
|
||||||
|
if s.customAddress != nil {
|
||||||
|
s.runtimeIP = s.customAddress.Addr().String()
|
||||||
|
s.runtimePort = int(s.customAddress.Port())
|
||||||
|
} else {
|
||||||
|
ip, port, err := s.getFirstListenerAvailable()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeIP = ip
|
||||||
|
s.runtimePort = port
|
||||||
|
}
|
||||||
|
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
err := s.server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||||
|
ips := []string{defaultIP, customIP}
|
||||||
|
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||||
|
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||||
|
}
|
||||||
|
ports := []int{defaultPort, customPort}
|
||||||
|
for _, port := range ports {
|
||||||
|
for _, ip := range ips {
|
||||||
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err == nil {
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the server
|
||||||
|
func (s *DefaultServer) Stop() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
s.ctxCancel()
|
||||||
|
|
||||||
|
err := s.hostManager.restoreHostDNS()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) stopListener() error {
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer processes an update received from the management service
|
||||||
|
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
log.Infof("not updating DNS server as context is closed")
|
||||||
|
return s.ctx.Err()
|
||||||
|
default:
|
||||||
|
if serial < s.updateSerial {
|
||||||
|
return fmt.Errorf("not applying dns update, error: "+
|
||||||
|
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||||
|
}
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
|
ZeroNil: true,
|
||||||
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.previousConfigHash == hash {
|
||||||
|
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||||
|
s.updateSerial = serial
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.updateSerial = serial
|
||||||
|
s.previousConfigHash = hash
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
|
// is the service should be disabled, we stop the listener
|
||||||
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
|
if !update.ServiceEnable {
|
||||||
|
err := s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
} else if !s.listenerIsRunning {
|
||||||
|
s.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
|
s.updateMux(muxUpdates)
|
||||||
|
s.updateLocalResolver(localRecords)
|
||||||
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
|
for _, customZone := range customZones {
|
||||||
|
|
||||||
|
if len(customZone.Records) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: customZone.Domain,
|
||||||
|
handler: s.localResolver,
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, record := range customZone.Records {
|
||||||
|
var class uint16 = dns.ClassINET
|
||||||
|
if record.Class != nbdns.DefaultClass {
|
||||||
|
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||||
|
}
|
||||||
|
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||||
|
localRecords[key] = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, localRecords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||||
|
// clean up the previous upstream resolver
|
||||||
|
if s.upstreamCtxCancel != nil {
|
||||||
|
s.upstreamCtxCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
for _, nsGroup := range nameServerGroups {
|
||||||
|
if len(nsGroup.NameServers) == 0 {
|
||||||
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||||
|
|
||||||
|
handler := newUpstreamResolver(ctx)
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
|
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||||
|
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(handler.upstreamServers) == 0 {
|
||||||
|
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// when upstream fails to resolve domain several times over all it servers
|
||||||
|
// it will calls this hook to exclude self from the configuration and
|
||||||
|
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||||
|
// because it is temporal deactivation until next try
|
||||||
|
//
|
||||||
|
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||||
|
// everything we need here is just to re-apply current configuration because it already
|
||||||
|
// contains this upstream settings (temporal deactivation not removed it)
|
||||||
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nsGroup.Domains) == 0 {
|
||||||
|
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
|
}
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: domain,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
|
muxUpdateMap := make(registrationMap)
|
||||||
|
|
||||||
|
for _, update := range muxUpdates {
|
||||||
|
s.registerMux(update.domain, update.handler)
|
||||||
|
muxUpdateMap[update.domain] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range s.dnsMuxMap {
|
||||||
|
_, found := muxUpdateMap[key]
|
||||||
|
if !found {
|
||||||
|
s.deregisterMux(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dnsMuxMap = muxUpdateMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
|
for key := range s.localResolver.registeredMap {
|
||||||
|
_, found := update[key]
|
||||||
|
if !found {
|
||||||
|
s.localResolver.deleteRecord(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedMap := make(registrationMap)
|
||||||
|
for key, record := range update {
|
||||||
|
err := s.localResolver.registerRecord(record)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||||
|
}
|
||||||
|
updatedMap[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.localResolver.registeredMap = updatedMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNSHostPort(ns nbdns.NameServer) string {
|
||||||
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
|
nsGroup *nbdns.NameServerGroup,
|
||||||
|
handler dns.Handler,
|
||||||
|
) (deactivate func(), reactivate func()) {
|
||||||
|
var removeIndex map[string]int
|
||||||
|
deactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Info("temporary deactivate nameservers group due timeout")
|
||||||
|
|
||||||
|
removeIndex = make(map[string]int)
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
removeIndex[domain] = -1
|
||||||
|
}
|
||||||
|
if nsGroup.Primary {
|
||||||
|
removeIndex[nbdns.RootZone] = -1
|
||||||
|
s.currentConfig.routeAll = false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range s.currentConfig.domains {
|
||||||
|
if _, found := removeIndex[item.domain]; found {
|
||||||
|
s.currentConfig.domains[i].disabled = true
|
||||||
|
s.deregisterMux(item.domain)
|
||||||
|
removeIndex[item.domain] = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
for domain, i := range removeIndex {
|
||||||
|
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.currentConfig.domains[i].disabled = false
|
||||||
|
s.registerMux(domain, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Debug("reactivate temporary disabled nameserver group")
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
s.currentConfig.routeAll = true
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -5,59 +5,18 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mocWGIface struct {
|
|
||||||
filter iface.PacketFilter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *mocWGIface) Name() string {
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *mocWGIface) Address() iface.WGAddress {
|
|
||||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: ip,
|
|
||||||
Network: network,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
|
||||||
return w.filter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *mocWGIface) IsUserspaceBind() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
|
||||||
w.filter = filter
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
{
|
{
|
||||||
Name: "peera.netbird.cloud",
|
Name: "peera.netbird.cloud",
|
||||||
@@ -68,11 +27,6 @@ var zoneRecords = []nbdns.SimpleRecord{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
|
||||||
log.SetLevel(log.TraceLevel)
|
|
||||||
formatter.SetTextFormatter(log.StandardLogger())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -87,23 +41,21 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
dummyHandler := &localResolver{}
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
initUpstreamMap registeredHandlerMap
|
initUpstreamMap registrationMap
|
||||||
initLocalMap registrationMap
|
initLocalMap registrationMap
|
||||||
initSerial uint64
|
initSerial uint64
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
inputUpdate nbdns.Config
|
inputUpdate nbdns.Config
|
||||||
shouldFail bool
|
shouldFail bool
|
||||||
expectedUpstreamMap registeredHandlerMap
|
expectedUpstreamMap registrationMap
|
||||||
expectedLocalMap registrationMap
|
expectedLocalMap registrationMap
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial Config Should Succeed",
|
name: "Initial Config Should Succeed",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalMap: make(registrationMap),
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registrationMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -125,13 +77,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
|
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
|
initUpstreamMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -149,13 +101,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
|
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalMap: make(registrationMap),
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registrationMap),
|
||||||
initSerial: 2,
|
initSerial: 2,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
shouldFail: true,
|
shouldFail: true,
|
||||||
@@ -163,7 +115,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalMap: make(registrationMap),
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registrationMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -185,7 +137,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Invalid NS Group Nameservers list Should Fail",
|
name: "Invalid NS Group Nameservers list Should Fail",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalMap: make(registrationMap),
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registrationMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -207,7 +159,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Invalid Custom Zone Records list Should Fail",
|
name: "Invalid Custom Zone Records list Should Fail",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalMap: make(registrationMap),
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registrationMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -229,21 +181,21 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: make(registrationMap),
|
||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: make(registrationMap),
|
||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -254,7 +206,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
|
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -272,10 +224,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = dnsServer.hostManager.restoreHostDNS()
|
err = dnsServer.hostManager.restoreHostDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -286,6 +234,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
// pretend we are running
|
||||||
|
dnsServer.listenerIsRunning = true
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -320,133 +270,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|
||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
||||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
||||||
|
|
||||||
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
newNet, err := stdnet.NewNet(nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create stdnet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("build interface wireguard: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("crate and init wireguard interface: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = wgIface.Close(); err != nil {
|
|
||||||
t.Logf("close wireguard interface: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("parse CIDR: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
|
||||||
t.Errorf("set packet filter: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("run DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
|
||||||
t.Logf("restore DNS settings on the host: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
|
||||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
|
||||||
dnsServer.updateSerial = 0
|
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
update := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the server with regular configuration
|
|
||||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update2 := update
|
|
||||||
update2.ServiceEnable = false
|
|
||||||
// Disable the server, stop the listener
|
|
||||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update3 := update2
|
|
||||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
|
||||||
// But service still get updates and we checking that we handle
|
|
||||||
// internal state in the right way
|
|
||||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -463,23 +286,21 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%v", err)
|
|
||||||
}
|
|
||||||
dnsServer.hostManager = newNoopHostMocker()
|
dnsServer.hostManager = newNoopHostMocker()
|
||||||
err = dnsServer.service.Listen()
|
dnsServer.Start()
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("dns server is not running: %s", err)
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if !dnsServer.listenerIsRunning {
|
||||||
|
t.Fatal("dns server listener is not running")
|
||||||
|
}
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@@ -487,7 +308,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
d := net.Dialer{
|
d := net.Dialer{
|
||||||
Timeout: time.Second * 5,
|
Timeout: time.Second * 5,
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
||||||
conn, err := d.DialContext(ctx, network, addr)
|
conn, err := d.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
@@ -522,7 +343,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
service: newServiceViaMemory(&mocWGIface{}),
|
dnsMux: dns.DefaultServeMux,
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@@ -585,237 +406,35 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||||
wgIFace, err := createWgInterfaceWithBind(t)
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
var parsedAddrPort *netip.AddrPort
|
||||||
|
if addrPort != "" {
|
||||||
|
parsed, err := netip.ParseAddrPort(addrPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to initialize wg interface")
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
parsedAddrPort = &parsed
|
||||||
|
|
||||||
var dnsList []string
|
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
defer dnsServer.Stop()
|
|
||||||
|
|
||||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
dnsServer := &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
}
|
||||||
|
|
||||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
||||||
if err != nil {
|
return &DefaultServer{
|
||||||
t.Errorf("failed to resolve: %s", err)
|
ctx: ctx,
|
||||||
}
|
ctxCancel: cancel,
|
||||||
}
|
server: dnsServer,
|
||||||
|
dnsMux: mux,
|
||||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
dnsMuxMap: make(registrationMap),
|
||||||
wgIFace, err := createWgInterfaceWithBind(t)
|
localResolver: &localResolver{
|
||||||
if err != nil {
|
registeredMap: make(registrationMap),
|
||||||
t.Fatal("failed to initialize wg interface")
|
},
|
||||||
}
|
customAddress: parsedAddrPort,
|
||||||
defer wgIFace.Close()
|
|
||||||
|
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer dnsServer.Stop()
|
|
||||||
|
|
||||||
// check initial state
|
|
||||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to resolve: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
update := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Enabled: true,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(1, update)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to update dns server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to resolve: %s", err)
|
|
||||||
}
|
|
||||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed resolve zone record: %v", err)
|
|
||||||
}
|
|
||||||
if ips[0] != zoneRecords[0].RData {
|
|
||||||
t.Fatalf("invalid zone record: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
update2 := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(2, update2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to update dns server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to resolve: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed resolve zone record: %v", err)
|
|
||||||
}
|
|
||||||
if ips[0] != zoneRecords[0].RData {
|
|
||||||
t.Fatalf("invalid zone record: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
|
||||||
wgIFace, err := createWgInterfaceWithBind(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("failed to initialize wg interface")
|
|
||||||
}
|
|
||||||
defer wgIFace.Close()
|
|
||||||
|
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer dnsServer.Stop()
|
|
||||||
|
|
||||||
// check initial state
|
|
||||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to resolve: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
update := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Domains: []string{"customdomain.com"},
|
|
||||||
Primary: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(1, update)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to update dns server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to resolve: %s", err)
|
|
||||||
}
|
|
||||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed resolve zone record: %v", err)
|
|
||||||
}
|
|
||||||
if ips[0] != zoneRecords[0].RData {
|
|
||||||
t.Fatalf("invalid zone record: %v", err)
|
|
||||||
}
|
|
||||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to resolve: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|
||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
||||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
||||||
|
|
||||||
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
newNet, err := stdnet.NewNet(nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create stdnet: %v", err)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("build interface wireguard: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("crate and init wireguard interface: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = wgIface.SetFilter(pf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("set packet filter: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return wgIface, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDnsResolver(ip string, port int) *net.Resolver {
|
|
||||||
return &net.Resolver{
|
|
||||||
PreferGo: true,
|
|
||||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
d := net.Dialer{
|
|
||||||
Timeout: time.Second * 3,
|
|
||||||
}
|
|
||||||
addr := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
return d.DialContext(ctx, network, addr)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultPort = 53
|
|
||||||
)
|
|
||||||
|
|
||||||
type service interface {
|
|
||||||
Listen() error
|
|
||||||
Stop()
|
|
||||||
RegisterMux(domain string, handler dns.Handler)
|
|
||||||
DeregisterMux(key string)
|
|
||||||
RuntimePort() int
|
|
||||||
RuntimeIP() string
|
|
||||||
}
|
|
||||||
@@ -1,145 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
customPort = 5053
|
|
||||||
defaultIP = "127.0.0.1"
|
|
||||||
customIP = "127.0.0.153"
|
|
||||||
)
|
|
||||||
|
|
||||||
type serviceViaListener struct {
|
|
||||||
wgInterface WGIface
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
customAddr *netip.AddrPort
|
|
||||||
server *dns.Server
|
|
||||||
runtimeIP string
|
|
||||||
runtimePort int
|
|
||||||
listenerIsRunning bool
|
|
||||||
listenerFlagLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
s := &serviceViaListener{
|
|
||||||
wgInterface: wgIface,
|
|
||||||
dnsMux: mux,
|
|
||||||
customAddr: customAddr,
|
|
||||||
server: &dns.Server{
|
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
UDPSize: 65535,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) Listen() error {
|
|
||||||
s.listenerFlagLock.Lock()
|
|
||||||
defer s.listenerFlagLock.Unlock()
|
|
||||||
|
|
||||||
if s.listenerIsRunning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to eval runtime address: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
err := s.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) Stop() {
|
|
||||||
s.listenerFlagLock.Lock()
|
|
||||||
defer s.listenerFlagLock.Unlock()
|
|
||||||
|
|
||||||
if !s.listenerIsRunning {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := s.server.ShutdownContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) DeregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) RuntimePort() int {
|
|
||||||
return s.runtimePort
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) RuntimeIP() string {
|
|
||||||
return s.runtimeIP
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
|
||||||
s.listenerIsRunning = running
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
|
||||||
ips := []string{defaultIP, customIP}
|
|
||||||
if runtime.GOOS != "darwin" {
|
|
||||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
|
||||||
}
|
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
|
||||||
for _, ip := range ips {
|
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
|
|
||||||
if s.customAddr != nil {
|
|
||||||
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.getFirstListenerAvailable()
|
|
||||||
}
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type serviceViaMemory struct {
|
|
||||||
wgInterface WGIface
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
runtimeIP string
|
|
||||||
runtimePort int
|
|
||||||
udpFilterHookID string
|
|
||||||
listenerIsRunning bool
|
|
||||||
listenerFlagLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
|
||||||
s := &serviceViaMemory{
|
|
||||||
wgInterface: wgIface,
|
|
||||||
dnsMux: dns.NewServeMux(),
|
|
||||||
|
|
||||||
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
|
||||||
runtimePort: defaultPort,
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaMemory) Listen() error {
|
|
||||||
s.listenerFlagLock.Lock()
|
|
||||||
defer s.listenerFlagLock.Unlock()
|
|
||||||
|
|
||||||
if s.listenerIsRunning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.listenerIsRunning = true
|
|
||||||
|
|
||||||
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaMemory) Stop() {
|
|
||||||
s.listenerFlagLock.Lock()
|
|
||||||
defer s.listenerFlagLock.Unlock()
|
|
||||||
|
|
||||||
if !s.listenerIsRunning {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
|
||||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.listenerIsRunning = false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaMemory) RuntimePort() int {
|
|
||||||
return s.runtimePort
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaMemory) RuntimeIP() string {
|
|
||||||
return s.runtimeIP
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
|
||||||
filter := s.wgInterface.GetFilter()
|
|
||||||
if filter == nil {
|
|
||||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
|
||||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
|
||||||
firstLayerDecoder = layers.LayerTypeIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
hook := func(packetData []byte) bool {
|
|
||||||
// Decode the packet
|
|
||||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
|
||||||
|
|
||||||
// Get the UDP layer
|
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
||||||
udp := udpLayer.(*layers.UDP)
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(udp.Payload); err != nil {
|
|
||||||
log.Tracef("parse DNS request: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := responseWriter{
|
|
||||||
packet: packet,
|
|
||||||
device: s.wgInterface.GetDevice().Device,
|
|
||||||
}
|
|
||||||
go s.dnsMux.ServeDNS(&writer, msg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
|
||||||
// Calculate the last IP in the CIDR range
|
|
||||||
var endIP net.IP
|
|
||||||
for i := 0; i < len(network.IP); i++ {
|
|
||||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert to big.Int
|
|
||||||
endInt := big.NewInt(0)
|
|
||||||
endInt.SetBytes(endIP)
|
|
||||||
|
|
||||||
// subtract fromEnd from the last ip
|
|
||||||
fromEndBig := big.NewInt(int64(fromEnd))
|
|
||||||
resultInt := big.NewInt(0)
|
|
||||||
resultInt.Sub(endInt, fromEndBig)
|
|
||||||
|
|
||||||
return net.IP(resultInt.Bytes()).String()
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
addr string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
|
||||||
{"192.168.0.0/30", "192.168.0.2"},
|
|
||||||
{"192.168.0.0/16", "192.168.255.254"},
|
|
||||||
{"192.168.0.0/24", "192.168.0.254"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error parsing CIDR: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
|
||||||
if lastIP != tt.ip {
|
|
||||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -11,10 +9,10 @@ import (
|
|||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -52,7 +50,7 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
MatchOnly bool
|
MatchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -77,10 +75,6 @@ func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
parsedIP, err := netip.ParseAddr(config.serverIP)
|
parsedIP, err := netip.ParseAddr(config.serverIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,31 +3,24 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
failsTillDeact = int32(5)
|
failsTillDeact = int32(3)
|
||||||
reactivatePeriod = 30 * time.Second
|
reactivatePeriod = time.Minute
|
||||||
upstreamTimeout = 15 * time.Second
|
upstreamTimeout = 15 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamClient interface {
|
|
||||||
ExchangeContext(ctx context.Context, m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
upstreamClient *dns.Client
|
||||||
upstreamClient upstreamClient
|
|
||||||
upstreamServers []string
|
upstreamServers []string
|
||||||
disabled bool
|
disabled bool
|
||||||
failsCount atomic.Int32
|
failsCount atomic.Int32
|
||||||
@@ -40,11 +33,9 @@ type upstreamResolver struct {
|
|||||||
reactivate func()
|
reactivate func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
|
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
|
||||||
ctx, cancel := context.WithCancel(parentCTX)
|
|
||||||
return &upstreamResolver{
|
return &upstreamResolver{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
|
||||||
upstreamClient: &dns.Client{},
|
upstreamClient: &dns.Client{},
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: upstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
@@ -52,11 +43,6 @@ func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) stop() {
|
|
||||||
log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers)
|
|
||||||
u.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
defer u.checkUpstreamFails()
|
defer u.checkUpstreamFails()
|
||||||
@@ -121,57 +107,28 @@ func (u *upstreamResolver) checkUpstreamFails() {
|
|||||||
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
||||||
u.deactivate()
|
u.deactivate()
|
||||||
u.disabled = true
|
u.disabled = true
|
||||||
go u.waitUntilResponse()
|
go u.waitUntilReactivation()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
|
// waitUntilReactivation reset fails counter and activates upstream resolving
|
||||||
func (u *upstreamResolver) waitUntilResponse() {
|
func (u *upstreamResolver) waitUntilReactivation() {
|
||||||
exponentialBackOff := &backoff.ExponentialBackOff{
|
timer := time.NewTimer(u.reactivatePeriod)
|
||||||
InitialInterval: 500 * time.Millisecond,
|
defer func() {
|
||||||
RandomizationFactor: 0.5,
|
if !timer.Stop() {
|
||||||
Multiplier: 1.1,
|
<-timer.C
|
||||||
MaxInterval: u.reactivatePeriod,
|
|
||||||
MaxElapsedTime: 0,
|
|
||||||
Stop: backoff.Stop,
|
|
||||||
Clock: backoff.SystemClock,
|
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
r := new(dns.Msg).SetQuestion("netbird.io.", dns.TypeA)
|
|
||||||
|
|
||||||
operation := func() error {
|
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
for _, upstream := range u.upstreamServers {
|
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
|
||||||
_, _, err = u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Tracef("checking connectivity with upstreams %s failed with error: %s. Retrying in %s", err, u.upstreamServers, exponentialBackOff.NextBackOff())
|
|
||||||
return fmt.Errorf("got an error from upstream check call")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := backoff.Retry(operation, exponentialBackOff)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(err)
|
|
||||||
return
|
return
|
||||||
}
|
case <-timer.C:
|
||||||
|
log.Info("upstream resolving is reactivated")
|
||||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
|
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
u.reactivate()
|
u.reactivate()
|
||||||
u.disabled = false
|
u.disabled = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTimeout returns true if the given error is a network timeout error.
|
// isTimeout returns true if the given error is a network timeout error.
|
||||||
|
|||||||
@@ -2,11 +2,10 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/miekg/dns"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||||
@@ -107,29 +106,8 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockUpstreamResolver struct {
|
|
||||||
r *dns.Msg
|
|
||||||
rtt time.Duration
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExchangeContext mock implementation of ExchangeContext from upstreamResolver
|
|
||||||
func (c mockUpstreamResolver) ExchangeContext(_ context.Context, _ *dns.Msg, _ string) (r *dns.Msg, rtt time.Duration, err error) {
|
|
||||||
return c.r, c.rtt, c.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||||
resolver := &upstreamResolver{
|
resolver := newUpstreamResolver(context.TODO())
|
||||||
ctx: context.TODO(),
|
|
||||||
upstreamClient: &mockUpstreamResolver{
|
|
||||||
err: nil,
|
|
||||||
r: new(dns.Msg),
|
|
||||||
rtt: time.Millisecond,
|
|
||||||
},
|
|
||||||
upstreamTimeout: upstreamTimeout,
|
|
||||||
reactivatePeriod: reactivatePeriod,
|
|
||||||
failsTillDeact: failsTillDeact,
|
|
||||||
}
|
|
||||||
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
||||||
resolver.failsTillDeact = 0
|
resolver.failsTillDeact = 0
|
||||||
resolver.reactivatePeriod = time.Microsecond * 100
|
resolver.reactivatePeriod = time.Microsecond * 100
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/iface"
|
|
||||||
|
|
||||||
// WGIface defines subset methods of interface required for manager
|
|
||||||
type WGIface interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
GetFilter() iface.PacketFilter
|
|
||||||
GetDevice() *iface.DeviceWrapper
|
|
||||||
}
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/iface"
|
|
||||||
|
|
||||||
// WGIface defines subset methods of interface required for manager
|
|
||||||
type WGIface interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
GetFilter() iface.PacketFilter
|
|
||||||
GetDevice() *iface.DeviceWrapper
|
|
||||||
GetInterfaceGUIDString() (string, error)
|
|
||||||
}
|
|
||||||
@@ -17,11 +17,9 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@@ -102,7 +100,6 @@ type Engine struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
wgProxyFactory *wgproxy.Factory
|
|
||||||
|
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *bind.UniversalUDPMuxDefault
|
||||||
udpMuxConn io.Closer
|
udpMuxConn io.Closer
|
||||||
@@ -116,7 +113,6 @@ type Engine struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
}
|
}
|
||||||
@@ -133,7 +129,6 @@ func NewEngine(
|
|||||||
signalClient signal.Client, mgmClient mgm.Client,
|
signalClient signal.Client, mgmClient mgm.Client,
|
||||||
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
|
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
@@ -148,7 +143,6 @@ func NewEngine(
|
|||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
sshServerFunc: nbssh.DefaultSSHServer,
|
sshServerFunc: nbssh.DefaultSSHServer,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,46 +179,13 @@ func (e *Engine) Start() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
|
||||||
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.TunAdapter, transportNet)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
|
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var routes []*route.Route
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
|
||||||
routes, err = e.readInitialSettings()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if e.dnsServer == nil {
|
|
||||||
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
|
|
||||||
go e.mobileDep.DnsReadyListener.OnReady()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// todo fix custom address
|
|
||||||
if e.dnsServer == nil {
|
|
||||||
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
|
||||||
if err != nil {
|
|
||||||
e.close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.RouteListener)
|
|
||||||
|
|
||||||
if runtime.GOOS != "android" {
|
|
||||||
err = e.wgInterface.Create()
|
err = e.wgInterface.Create()
|
||||||
} else {
|
|
||||||
err = e.wgInterface.CreateOnMobile(iface.MobileIFaceArguments{
|
|
||||||
Routes: e.routeManager.InitialRouteRange(),
|
|
||||||
Dns: e.dnsServer.DnsIP(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
@@ -258,17 +219,17 @@ func (e *Engine) Start() error {
|
|||||||
e.udpMux = mux
|
e.udpMux = mux
|
||||||
}
|
}
|
||||||
|
|
||||||
if acl, err := acl.Create(e.wgInterface); err != nil {
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||||
log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error())
|
|
||||||
} else {
|
|
||||||
e.acl = acl
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.dnsServer.Initialize()
|
if e.dnsServer == nil {
|
||||||
|
// todo fix custom address
|
||||||
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
e.dnsServer = dnsServer
|
||||||
|
}
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
@@ -610,7 +571,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
// cleanup request, most likely our peer has been deleted
|
// cleanup request, most likely our peer has been deleted
|
||||||
if networkMap.GetRemotePeersIsEmpty() {
|
if networkMap.GetRemotePeersIsEmpty() {
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -630,8 +590,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
|
||||||
|
|
||||||
// update SSHServer by adding remote peer SSH keys
|
// update SSHServer by adding remote peer SSH keys
|
||||||
if !isNil(e.sshServer) {
|
if !isNil(e.sshServer) {
|
||||||
for _, config := range networkMap.GetRemotePeers() {
|
for _, config := range networkMap.GetRemotePeers() {
|
||||||
@@ -663,9 +621,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.acl != nil {
|
|
||||||
e.acl.ApplyFiltering(networkMap)
|
|
||||||
}
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -767,13 +722,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
e.peerConns[peerKey] = conn
|
||||||
|
|
||||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
err = e.statusRecorder.AddPeer(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go e.connWorker(conn, peerKey)
|
go e.connWorker(conn, peerKey)
|
||||||
}
|
}
|
||||||
|
err := e.statusRecorder.UpdatePeerFQDN(peerKey, peerConfig.Fqdn)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerKey, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -814,14 +773,14 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) peerExists(peerKey string) bool {
|
func (e Engine) peerExists(peerKey string) bool {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
_, ok := e.peerConns[peerKey]
|
_, ok := e.peerConns[peerKey]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
|
func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
|
||||||
log.Debugf("creating peer connection %s", pubKey)
|
log.Debugf("creating peer connection %s", pubKey)
|
||||||
var stunTurn []*ice.URL
|
var stunTurn []*ice.URL
|
||||||
stunTurn = append(stunTurn, e.STUNs...)
|
stunTurn = append(stunTurn, e.STUNs...)
|
||||||
@@ -852,7 +811,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1009,10 +968,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
if err := e.wgProxyFactory.Free(); err != nil {
|
|
||||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
if err := e.wgInterface.Close(); err != nil {
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
@@ -1047,18 +1002,6 @@ func (e *Engine) close() {
|
|||||||
e.dnsServer.Stop()
|
e.dnsServer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.acl != nil {
|
|
||||||
e.acl.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, error) {
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
|
||||||
return routes, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -13,7 +15,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pion/transport/v2/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -28,7 +29,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/iface/bind"
|
|
||||||
mgmt "github.com/netbirdio/netbird/management/client"
|
mgmt "github.com/netbirdio/netbird/management/client"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@@ -213,11 +213,11 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet)
|
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
}
|
}
|
||||||
@@ -367,9 +367,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||||
}
|
}
|
||||||
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
|
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
|
||||||
if conn.WgConfig().AllowedIps != expectedAllowedIPs {
|
if conn.GetConf().ProxyConfig.AllowedIps != expectedAllowedIPs {
|
||||||
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
|
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
|
||||||
expectedAllowedIPs, conn.WgConfig().AllowedIps)
|
expectedAllowedIPs, conn.GetConf().ProxyConfig.AllowedIps)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -567,7 +567,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
input := struct {
|
input := struct {
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
@@ -736,7 +736,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
@@ -1039,7 +1039,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
store, err := server.NewFileStore(config.Datadir, nil)
|
store, err := server.NewFileStore(config.Datadir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
@@ -11,7 +9,5 @@ import (
|
|||||||
type MobileDependency struct {
|
type MobileDependency struct {
|
||||||
TunAdapter iface.TunAdapter
|
TunAdapter iface.TunAdapter
|
||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
RouteListener routemanager.RouteListener
|
Routes []string
|
||||||
HostDNSAddresses []string
|
|
||||||
DnsReadyListener dns.ReadyListener
|
|
||||||
}
|
}
|
||||||
|
|||||||
29
client/internal/mobile_dependency_android.go
Normal file
29
client/internal/mobile_dependency_android.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||||
|
md := MobileDependency{
|
||||||
|
TunAdapter: tunAdapter,
|
||||||
|
IFaceDiscover: ifaceDiscover,
|
||||||
|
}
|
||||||
|
err := md.readMap(mgmClient)
|
||||||
|
return md, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
|
||||||
|
routes, err := mgmClient.GetRoutes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Routes = make([]string, len(routes))
|
||||||
|
for i, r := range routes {
|
||||||
|
d.Routes[i] = r.GetNetwork()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
13
client/internal/mobile_dependency_nonandroid.go
Normal file
13
client/internal/mobile_dependency_nonandroid.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||||
|
return MobileDependency{}, nil
|
||||||
|
}
|
||||||
286
client/internal/oauth.go
Normal file
286
client/internal/oauth.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthClient is a OAuth client interface for various idp providers
|
||||||
|
type OAuthClient interface {
|
||||||
|
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||||
|
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
||||||
|
GetClientID(ctx context.Context) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient http client interface for API calls
|
||||||
|
type HTTPClient interface {
|
||||||
|
Do(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuthInfo holds information for the OAuth device login flow
|
||||||
|
type DeviceAuthInfo struct {
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostedGrantType grant type for device flow on Hosted
|
||||||
|
const (
|
||||||
|
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
|
HostedRefreshGrant = "refresh_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hosted client
|
||||||
|
type Hosted struct {
|
||||||
|
providerConfig ProviderConfig
|
||||||
|
|
||||||
|
HTTPClient HTTPClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||||
|
type RequestDeviceCodePayload struct {
|
||||||
|
Audience string `json:"audience"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRequestPayload used for requesting the auth0 token
|
||||||
|
type TokenRequestPayload struct {
|
||||||
|
GrantType string `json:"grant_type"`
|
||||||
|
DeviceCode string `json:"device_code,omitempty"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRequestResponse used for parsing Hosted token's response
|
||||||
|
type TokenRequestResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
TokenInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claims used when validating the access token
|
||||||
|
type Claims struct {
|
||||||
|
Audience interface{} `json:"aud"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenInfo holds information of issued access token
|
||||||
|
type TokenInfo struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
UseIDToken bool `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||||
|
func (t TokenInfo) GetTokenToUse() string {
|
||||||
|
if t.UseIDToken {
|
||||||
|
return t.IDToken
|
||||||
|
}
|
||||||
|
return t.AccessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHostedDeviceFlow returns an Hosted OAuth client
|
||||||
|
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Hosted{
|
||||||
|
providerConfig: config,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientID returns the provider client id
|
||||||
|
func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||||
|
return h.providerConfig.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||||
|
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("client_id", h.providerConfig.ClientID)
|
||||||
|
form.Add("audience", h.providerConfig.Audience)
|
||||||
|
form.Add("scope", h.providerConfig.Scope)
|
||||||
|
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
|
||||||
|
strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
res, err := h.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := DeviceAuthInfo{}
|
||||||
|
err = json.Unmarshal(body, &deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||||
|
if deviceCode.VerificationURIComplete == "" {
|
||||||
|
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||||
|
}
|
||||||
|
|
||||||
|
return deviceCode, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("client_id", h.providerConfig.ClientID)
|
||||||
|
form.Add("grant_type", HostedGrantType)
|
||||||
|
form.Add("device_code", info.DeviceCode)
|
||||||
|
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
res, err := h.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode > 499 {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResponse := TokenRequestResponse{}
|
||||||
|
err = json.Unmarshal(body, &tokenResponse)
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||||
|
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||||
|
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
||||||
|
interval := time.Duration(info.Interval) * time.Second
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return TokenInfo{}, ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
|
||||||
|
tokenResponse, err := h.requestToken(info)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenResponse.Error != "" {
|
||||||
|
if tokenResponse.Error == "authorization_pending" {
|
||||||
|
continue
|
||||||
|
} else if tokenResponse.Error == "slow_down" {
|
||||||
|
interval = interval + (3 * time.Second)
|
||||||
|
ticker.Reset(interval)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo := TokenInfo{
|
||||||
|
AccessToken: tokenResponse.AccessToken,
|
||||||
|
TokenType: tokenResponse.TokenType,
|
||||||
|
RefreshToken: tokenResponse.RefreshToken,
|
||||||
|
IDToken: tokenResponse.IDToken,
|
||||||
|
ExpiresIn: tokenResponse.ExpiresIn,
|
||||||
|
UseIDToken: h.providerConfig.UseIDToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenInfo, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidAccessToken is a simple validation of the access token
|
||||||
|
func isValidAccessToken(token string, audience string) error {
|
||||||
|
if token == "" {
|
||||||
|
return fmt.Errorf("token received is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
encodedClaims := strings.Split(token, ".")[1]
|
||||||
|
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := Claims{}
|
||||||
|
err = json.Unmarshal(claimsString, &claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Audience == nil {
|
||||||
|
return fmt.Errorf("required token field audience is absent")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audience claim of JWT can be a string or an array of strings
|
||||||
|
typ := reflect.TypeOf(claims.Audience)
|
||||||
|
switch typ.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
if claims.Audience == audience {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
for _, aud := range claims.Audience.([]interface{}) {
|
||||||
|
if audience == aud {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid JWT token audience field")
|
||||||
|
}
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
package auth
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
type mockHTTPClient struct {
|
||||||
@@ -53,7 +53,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
testingErrFunc require.ErrorAssertionFunc
|
testingErrFunc require.ErrorAssertionFunc
|
||||||
expectedErrorMSG string
|
expectedErrorMSG string
|
||||||
testingFunc require.ComparisonAssertionFunc
|
testingFunc require.ComparisonAssertionFunc
|
||||||
expectedOut AuthFlowInfo
|
expectedOut DeviceAuthInfo
|
||||||
expectedMSG string
|
expectedMSG string
|
||||||
expectPayload string
|
expectPayload string
|
||||||
}
|
}
|
||||||
@@ -92,7 +92,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
expectPayload: expectPayload,
|
expectPayload: expectPayload,
|
||||||
}
|
}
|
||||||
testCase4Out := AuthFlowInfo{ExpiresIn: 10}
|
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
|
||||||
testCase4 := test{
|
testCase4 := test{
|
||||||
name: "Got Device Code",
|
name: "Got Device Code",
|
||||||
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
||||||
@@ -113,8 +113,8 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
err: testCase.inputReqError,
|
err: testCase.inputReqError,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := &DeviceAuthorizationFlow{
|
hosted := Hosted{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
providerConfig: ProviderConfig{
|
||||||
Audience: expectedAudience,
|
Audience: expectedAudience,
|
||||||
ClientID: expectedClientID,
|
ClientID: expectedClientID,
|
||||||
Scope: expectedScope,
|
Scope: expectedScope,
|
||||||
@@ -125,7 +125,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
HTTPClient: &httpClient,
|
HTTPClient: &httpClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
|
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
|
||||||
@@ -145,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
inputMaxReqs int
|
inputMaxReqs int
|
||||||
inputCountResBody string
|
inputCountResBody string
|
||||||
inputTimeout time.Duration
|
inputTimeout time.Duration
|
||||||
inputInfo AuthFlowInfo
|
inputInfo DeviceAuthInfo
|
||||||
inputAudience string
|
inputAudience string
|
||||||
testingErrFunc require.ErrorAssertionFunc
|
testingErrFunc require.ErrorAssertionFunc
|
||||||
expectedErrorMSG string
|
expectedErrorMSG string
|
||||||
@@ -155,7 +155,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
expectPayload string
|
expectPayload string
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultInfo := AuthFlowInfo{
|
defaultInfo := DeviceAuthInfo{
|
||||||
DeviceCode: "test",
|
DeviceCode: "test",
|
||||||
ExpiresIn: 10,
|
ExpiresIn: 10,
|
||||||
Interval: 1,
|
Interval: 1,
|
||||||
@@ -278,8 +278,8 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
countResBody: testCase.inputCountResBody,
|
countResBody: testCase.inputCountResBody,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := DeviceAuthorizationFlow{
|
hosted := Hosted{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
providerConfig: ProviderConfig{
|
||||||
Audience: testCase.inputAudience,
|
Audience: testCase.inputAudience,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
@@ -287,12 +287,11 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
Scope: "openid",
|
Scope: "openid",
|
||||||
UseIDToken: false,
|
UseIDToken: false,
|
||||||
},
|
},
|
||||||
HTTPClient: &httpClient,
|
HTTPClient: &httpClient}
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo)
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
|
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
|
||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,7 +15,6 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/iface/bind"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
@@ -22,6 +23,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
||||||
|
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
||||||
|
|
||||||
iceKeepAliveDefault = 4 * time.Second
|
iceKeepAliveDefault = 4 * time.Second
|
||||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||||
|
|
||||||
@@ -109,9 +113,7 @@ type Conn struct {
|
|||||||
|
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
|
|
||||||
wgProxyFactory *wgproxy.Factory
|
proxy *WireGuardProxy
|
||||||
wgProxy wgproxy.Proxy
|
|
||||||
|
|
||||||
remoteModeCh chan ModeMessage
|
remoteModeCh chan ModeMessage
|
||||||
meta meta
|
meta meta
|
||||||
|
|
||||||
@@ -147,7 +149,7 @@ func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
|
|||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
config: config,
|
config: config,
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
@@ -157,7 +159,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
|
|||||||
remoteAnswerCh: make(chan OfferAnswer),
|
remoteAnswerCh: make(chan OfferAnswer),
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
remoteModeCh: make(chan ModeMessage, 1),
|
remoteModeCh: make(chan ModeMessage, 1),
|
||||||
wgProxyFactory: wgProxyFactory,
|
|
||||||
adapter: adapter,
|
adapter: adapter,
|
||||||
iFaceDiscover: iFaceDiscover,
|
iFaceDiscover: iFaceDiscover,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -175,14 +176,13 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive, iceDisconnectedTimeout := readICEAgentConfigProperties()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
|
||||||
|
|
||||||
agentConfig := &ice.AgentConfig{
|
agentConfig := &ice.AgentConfig{
|
||||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||||
Urls: conn.config.StunTurn,
|
Urls: conn.config.StunTurn,
|
||||||
CandidateTypes: conn.candidateTypes(),
|
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
|
||||||
FailedTimeout: &failedTimeout,
|
FailedTimeout: &failedTimeout,
|
||||||
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||||
UDPMux: conn.config.UDPMux,
|
UDPMux: conn.config.UDPMux,
|
||||||
@@ -221,11 +221,32 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) candidateTypes() []ice.CandidateType {
|
func readICEAgentConfigProperties() (time.Duration, time.Duration) {
|
||||||
if hasICEForceRelayConn() {
|
iceKeepAlive := iceKeepAliveDefault
|
||||||
return []ice.CandidateType{ice.CandidateTypeRelay}
|
iceDisconnectedTimeout := iceDisconnectedTimeoutDefault
|
||||||
|
|
||||||
|
keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec)
|
||||||
|
if keepAliveEnv != "" {
|
||||||
|
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
||||||
|
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
|
||||||
|
if err == nil {
|
||||||
|
iceKeepAlive = time.Duration(keepAliveEnvSec) * time.Second
|
||||||
|
} else {
|
||||||
|
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAlive)
|
||||||
}
|
}
|
||||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
}
|
||||||
|
|
||||||
|
disconnectedTimeoutEnv := os.Getenv(envICEDisconnectedTimeoutSec)
|
||||||
|
if disconnectedTimeoutEnv != "" {
|
||||||
|
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
||||||
|
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
|
||||||
|
if err == nil {
|
||||||
|
iceDisconnectedTimeout = time.Duration(disconnectedTimeoutSec) * time.Second
|
||||||
|
} else {
|
||||||
|
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return iceKeepAlive, iceDisconnectedTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
||||||
@@ -334,7 +355,7 @@ func (conn *Conn) Open() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
|
log.Infof("connected to peer %s, proxy: %v, remote address: %s", conn.config.Key, conn.proxy != nil, remoteAddr.String())
|
||||||
|
|
||||||
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
||||||
select {
|
select {
|
||||||
@@ -363,10 +384,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (ne
|
|||||||
|
|
||||||
var endpoint net.Addr
|
var endpoint net.Addr
|
||||||
if isRelayCandidate(pair.Local) {
|
if isRelayCandidate(pair.Local) {
|
||||||
log.Debugf("setup relay connection")
|
conn.proxy = NewWireGuardProxy(conn.config.WgConfig.WgListenPort, conn.config.WgConfig.RemoteKey, remoteConn)
|
||||||
conn.wgProxy = conn.wgProxyFactory.GetProxy()
|
endpoint, err = conn.proxy.Start()
|
||||||
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
conn.proxy = nil
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -375,12 +396,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (ne
|
|||||||
endpoint = remoteConn.RemoteAddr()
|
endpoint = remoteConn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpoint, conn.config.WgConfig.PreSharedKey)
|
||||||
|
|
||||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if conn.wgProxy != nil {
|
if conn.proxy != nil {
|
||||||
_ = conn.wgProxy.CloseConn()
|
_ = conn.proxy.Close()
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -393,7 +412,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (ne
|
|||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
LocalIceCandidateType: pair.Local.Type().String(),
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
Direct: !isRelayCandidate(pair.Local),
|
Direct: conn.proxy == nil,
|
||||||
}
|
}
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
@@ -441,13 +460,15 @@ func (conn *Conn) cleanup() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxy != nil {
|
|
||||||
err2 = conn.wgProxy.CloseConn()
|
|
||||||
conn.wgProxy = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo: is it problem if we try to remove a peer what is never existed?
|
// todo: is it problem if we try to remove a peer what is never existed?
|
||||||
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
err2 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
|
||||||
|
if conn.proxy != nil {
|
||||||
|
err3 = conn.proxy.Close()
|
||||||
|
if err3 != nil {
|
||||||
|
conn.proxy = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if conn.notifyDisconnected != nil {
|
if conn.notifyDisconnected != nil {
|
||||||
conn.notifyDisconnected()
|
conn.notifyDisconnected()
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/magiconair/properties/assert"
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,6 +20,7 @@ var connConf = ConnConfig{
|
|||||||
StunTurn: []*ice.URL{},
|
StunTurn: []*ice.URL{},
|
||||||
InterfaceBlackList: nil,
|
InterfaceBlackList: nil,
|
||||||
Timeout: time.Second,
|
Timeout: time.Second,
|
||||||
|
ProxyConfig: proxy.Config{},
|
||||||
LocalWgPort: 51820,
|
LocalWgPort: 51820,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,11 +37,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
conn, err := NewConn(connConf, nil, nil, nil)
|
||||||
defer func() {
|
|
||||||
_ = wgProxyFactory.Free()
|
|
||||||
}()
|
|
||||||
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -50,11 +48,8 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
|
||||||
defer func() {
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
_ = wgProxyFactory.Free()
|
|
||||||
}()
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -87,11 +82,8 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
|
||||||
defer func() {
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
_ = wgProxyFactory.Free()
|
|
||||||
}()
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -123,11 +115,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
|
||||||
defer func() {
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
_ = wgProxyFactory.Free()
|
|
||||||
}()
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -153,11 +142,8 @@ func TestConn_Status(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_Close(t *testing.T) {
|
func TestConn_Close(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
|
||||||
defer func() {
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
_ = wgProxyFactory.Free()
|
|
||||||
}()
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
|
||||||
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
|
||||||
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
|
|
||||||
)
|
|
||||||
|
|
||||||
func iceKeepAlive() time.Duration {
|
|
||||||
keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec)
|
|
||||||
if keepAliveEnv == "" {
|
|
||||||
return iceKeepAliveDefault
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
|
||||||
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
|
|
||||||
return iceKeepAliveDefault
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Duration(keepAliveEnvSec) * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
func iceDisconnectedTimeout() time.Duration {
|
|
||||||
disconnectedTimeoutEnv := os.Getenv(envICEDisconnectedTimeoutSec)
|
|
||||||
if disconnectedTimeoutEnv == "" {
|
|
||||||
return iceDisconnectedTimeoutDefault
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
|
||||||
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
|
|
||||||
return iceDisconnectedTimeoutDefault
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Duration(disconnectedTimeoutSec) * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasICEForceRelayConn() bool {
|
|
||||||
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
|
|
||||||
return strings.ToLower(disconnectedTimeoutEnv) == "true"
|
|
||||||
}
|
|
||||||
@@ -59,11 +59,6 @@ type Status struct {
|
|||||||
mgmAddress string
|
mgmAddress string
|
||||||
signalAddress string
|
signalAddress string
|
||||||
notifier *notifier
|
notifier *notifier
|
||||||
|
|
||||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
|
||||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
|
||||||
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
|
|
||||||
peerListChangedForNotification bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecorder returns a new Status instance
|
// NewRecorder returns a new Status instance
|
||||||
@@ -83,13 +78,11 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
d.offlinePeers = make([]State, len(replacement))
|
d.offlinePeers = make([]State, len(replacement))
|
||||||
copy(d.offlinePeers, replacement)
|
copy(d.offlinePeers, replacement)
|
||||||
|
d.notifyPeerListChanged()
|
||||||
// todo we should set to true in case if the list changed only
|
|
||||||
d.peerListChangedForNotification = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeer adds peer to Daemon status map
|
// AddPeer adds peer to Daemon status map
|
||||||
func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
func (d *Status) AddPeer(peerPubKey string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
@@ -97,12 +90,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
|||||||
if ok {
|
if ok {
|
||||||
return errors.New("peer already exist")
|
return errors.New("peer already exist")
|
||||||
}
|
}
|
||||||
d.peers[peerPubKey] = State{
|
d.peers[peerPubKey] = State{PubKey: peerPubKey, ConnStatus: StatusDisconnected}
|
||||||
PubKey: peerPubKey,
|
|
||||||
ConnStatus: StatusDisconnected,
|
|
||||||
FQDN: fqdn,
|
|
||||||
}
|
|
||||||
d.peerListChangedForNotification = true
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,13 +112,13 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
_, ok := d.peers[peerPubKey]
|
_, ok := d.peers[peerPubKey]
|
||||||
if !ok {
|
if ok {
|
||||||
return errors.New("no peer with to remove")
|
delete(d.peers, peerPubKey)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(d.peers, peerPubKey)
|
d.notifyPeerListChanged()
|
||||||
d.peerListChangedForNotification = true
|
return errors.New("no peer with to remove")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerState updates peer status
|
// UpdatePeerState updates peer status
|
||||||
@@ -147,8 +135,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
peerState.IP = receivedState.IP
|
peerState.IP = receivedState.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
skipNotification := shouldSkipNotify(receivedState, peerState)
|
|
||||||
|
|
||||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||||
peerState.ConnStatus = receivedState.ConnStatus
|
peerState.ConnStatus = receivedState.ConnStatus
|
||||||
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
||||||
@@ -160,10 +146,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
if skipNotification {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
ch, found := d.changeNotify[receivedState.PubKey]
|
||||||
if found && ch != nil {
|
if found && ch != nil {
|
||||||
close(ch)
|
close(ch)
|
||||||
@@ -174,19 +156,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldSkipNotify(new, curr State) bool {
|
|
||||||
switch {
|
|
||||||
case new.ConnStatus == StatusConnecting:
|
|
||||||
return true
|
|
||||||
case new.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
|
|
||||||
return true
|
|
||||||
case new.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
|
|
||||||
return curr.IP != ""
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerFQDN update peer's state fqdn only
|
// UpdatePeerFQDN update peer's state fqdn only
|
||||||
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@@ -200,21 +169,8 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|||||||
peerState.FQDN = fqdn
|
peerState.FQDN = fqdn
|
||||||
d.peers[peerPubKey] = peerState
|
d.peers[peerPubKey] = peerState
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FinishPeerListModifications this event invoke the notification
|
|
||||||
func (d *Status) FinishPeerListModifications() {
|
|
||||||
d.mux.Lock()
|
|
||||||
|
|
||||||
if !d.peerListChangedForNotification {
|
|
||||||
d.mux.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d.peerListChangedForNotification = false
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ import (
|
|||||||
func TestAddPeer(t *testing.T) {
|
func TestAddPeer(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
err := status.AddPeer(key, "abc.netbird")
|
err := status.AddPeer(key)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
_, exists := status.peers[key]
|
_, exists := status.peers[key]
|
||||||
assert.True(t, exists, "value was found")
|
assert.True(t, exists, "value was found")
|
||||||
|
|
||||||
err = status.AddPeer(key, "abc.netbird")
|
err = status.AddPeer(key)
|
||||||
|
|
||||||
assert.Error(t, err, "should return error on duplicate")
|
assert.Error(t, err, "should return error on duplicate")
|
||||||
}
|
}
|
||||||
@@ -23,7 +23,7 @@ func TestAddPeer(t *testing.T) {
|
|||||||
func TestGetPeer(t *testing.T) {
|
func TestGetPeer(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
err := status.AddPeer(key, "abc.netbird")
|
err := status.AddPeer(key)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
peerStatus, err := status.GetPeer(key)
|
peerStatus, err := status.GetPeer(key)
|
||||||
|
|||||||
@@ -1,71 +1,68 @@
|
|||||||
package wgproxy
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGUserSpaceProxy proxies
|
// WireGuardProxy proxies
|
||||||
type WGUserSpaceProxy struct {
|
type WireGuardProxy struct {
|
||||||
localWGListenPort int
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
wgListenPort int
|
||||||
|
remoteKey string
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
func NewWireGuardProxy(wgListenPort int, remoteKey string, remoteConn net.Conn) *WireGuardProxy {
|
||||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
p := &WireGuardProxy{
|
||||||
p := &WGUserSpaceProxy{
|
wgListenPort: wgListenPort,
|
||||||
localWGListenPort: wgPort,
|
remoteKey: remoteKey,
|
||||||
|
remoteConn: remoteConn,
|
||||||
}
|
}
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
func (p *WireGuardProxy) Start() (net.Addr, error) {
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
lConn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", p.wgListenPort))
|
||||||
p.remoteConn = remoteConn
|
|
||||||
|
|
||||||
var err error
|
|
||||||
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
p.localConn = lConn
|
||||||
|
|
||||||
go p.proxyToRemote()
|
go p.proxyToRemote()
|
||||||
go p.proxyToLocal()
|
go p.proxyToLocal()
|
||||||
|
|
||||||
return p.localConn.LocalAddr(), err
|
return lConn.LocalAddr(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
func (p *WireGuardProxy) Close() error {
|
||||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
if p.localConn == nil {
|
if p.localConn != nil {
|
||||||
return nil
|
err := p.localConn.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return p.localConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free doing nothing because this implementation of proxy does not have global state
|
|
||||||
func (p *WGUserSpaceProxy) Free() error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
func (p *WireGuardProxy) proxyToRemote() {
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
|
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.remoteKey)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
n, err := p.localConn.Read(buf)
|
n, err := p.localConn.Read(buf)
|
||||||
@@ -83,12 +80,13 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
|
|
||||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
func (p *WireGuardProxy) proxyToLocal() {
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
|
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
|
||||||
type PKCEAuthorizationFlow struct {
|
|
||||||
ProviderConfig PKCEAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
|
||||||
type PKCEAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
|
||||||
AuthorizationEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// RedirectURL handles authorization code from IDP manager
|
|
||||||
RedirectURLs []string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
|
||||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authFlow := PKCEAuthorizationFlow{
|
|
||||||
ProviderConfig: PKCEAuthProviderConfig{
|
|
||||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
|
||||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
|
||||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return authFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.Audience == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
|
||||||
}
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.AuthorizationEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
|
||||||
}
|
|
||||||
if config.RedirectURLs == nil {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -71,7 +71,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||||
chosen := ""
|
var chosen string
|
||||||
chosenScore := 0
|
chosenScore := 0
|
||||||
|
|
||||||
currID := ""
|
currID := ""
|
||||||
@@ -85,26 +85,17 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
if !found || !peerStatus.connected {
|
if !found || !peerStatus.connected {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Metric < route.MaxMetric {
|
if r.Metric < route.MaxMetric {
|
||||||
metricDiff := route.MaxMetric - r.Metric
|
metricDiff := route.MaxMetric - r.Metric
|
||||||
tempScore = metricDiff * 10
|
tempScore = metricDiff * 10
|
||||||
}
|
}
|
||||||
|
|
||||||
if !peerStatus.relayed {
|
if !peerStatus.relayed {
|
||||||
tempScore++
|
tempScore++
|
||||||
}
|
}
|
||||||
|
if !peerStatus.direct {
|
||||||
if peerStatus.direct {
|
|
||||||
tempScore++
|
tempScore++
|
||||||
}
|
}
|
||||||
|
if tempScore > chosenScore || (tempScore == chosenScore && currID == r.ID) {
|
||||||
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
|
|
||||||
chosen = r.ID
|
|
||||||
chosenScore = tempScore
|
|
||||||
}
|
|
||||||
|
|
||||||
if chosen == "" && currID == "" {
|
|
||||||
chosen = r.ID
|
chosen = r.ID
|
||||||
chosenScore = tempScore
|
chosenScore = tempScore
|
||||||
}
|
}
|
||||||
@@ -115,9 +106,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
peers = append(peers, r.Peer)
|
peers = append(peers, r.Peer)
|
||||||
}
|
}
|
||||||
|
log.Warnf("no route was chosen for network %s because no peers from list %s were connected", c.network, peers)
|
||||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
|
||||||
|
|
||||||
} else if chosen != currID {
|
} else if chosen != currID {
|
||||||
log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore)
|
log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,199 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetBestrouteFromStatuses(t *testing.T) {
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
statuses map[string]routerPeerStatus
|
|
||||||
expectedRouteID string
|
|
||||||
currentRoute *route.Route
|
|
||||||
existingRoutes map[string]*route.Route
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "one route",
|
|
||||||
statuses: map[string]routerPeerStatus{
|
|
||||||
"route1": {
|
|
||||||
connected: true,
|
|
||||||
relayed: false,
|
|
||||||
direct: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
existingRoutes: map[string]*route.Route{
|
|
||||||
"route1": {
|
|
||||||
ID: "route1",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
currentRoute: nil,
|
|
||||||
expectedRouteID: "route1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "one connected routes with relayed and direct",
|
|
||||||
statuses: map[string]routerPeerStatus{
|
|
||||||
"route1": {
|
|
||||||
connected: true,
|
|
||||||
relayed: true,
|
|
||||||
direct: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
existingRoutes: map[string]*route.Route{
|
|
||||||
"route1": {
|
|
||||||
ID: "route1",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
currentRoute: nil,
|
|
||||||
expectedRouteID: "route1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "one connected routes with relayed and no direct",
|
|
||||||
statuses: map[string]routerPeerStatus{
|
|
||||||
"route1": {
|
|
||||||
connected: true,
|
|
||||||
relayed: true,
|
|
||||||
direct: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
existingRoutes: map[string]*route.Route{
|
|
||||||
"route1": {
|
|
||||||
ID: "route1",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
currentRoute: nil,
|
|
||||||
expectedRouteID: "route1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no connected peers",
|
|
||||||
statuses: map[string]routerPeerStatus{
|
|
||||||
"route1": {
|
|
||||||
connected: false,
|
|
||||||
relayed: false,
|
|
||||||
direct: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
existingRoutes: map[string]*route.Route{
|
|
||||||
"route1": {
|
|
||||||
ID: "route1",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
currentRoute: nil,
|
|
||||||
expectedRouteID: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple connected peers with different metrics",
|
|
||||||
statuses: map[string]routerPeerStatus{
|
|
||||||
"route1": {
|
|
||||||
connected: true,
|
|
||||||
relayed: false,
|
|
||||||
direct: true,
|
|
||||||
},
|
|
||||||
"route2": {
|
|
||||||
connected: true,
|
|
||||||
relayed: false,
|
|
||||||
direct: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
existingRoutes: map[string]*route.Route{
|
|
||||||
"route1": {
|
|
||||||
ID: "route1",
|
|
||||||
Metric: 9000,
|
|
||||||
Peer: "peer1",
|
|
||||||
},
|
|
||||||
"route2": {
|
|
||||||
ID: "route2",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
currentRoute: nil,
|
|
||||||
expectedRouteID: "route1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple connected peers with one relayed",
|
|
||||||
statuses: map[string]routerPeerStatus{
|
|
||||||
"route1": {
|
|
||||||
connected: true,
|
|
||||||
relayed: false,
|
|
||||||
direct: true,
|
|
||||||
},
|
|
||||||
"route2": {
|
|
||||||
connected: true,
|
|
||||||
relayed: true,
|
|
||||||
direct: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
existingRoutes: map[string]*route.Route{
|
|
||||||
"route1": {
|
|
||||||
ID: "route1",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer1",
|
|
||||||
},
|
|
||||||
"route2": {
|
|
||||||
ID: "route2",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
currentRoute: nil,
|
|
||||||
expectedRouteID: "route1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple connected peers with one direct",
|
|
||||||
statuses: map[string]routerPeerStatus{
|
|
||||||
"route1": {
|
|
||||||
connected: true,
|
|
||||||
relayed: false,
|
|
||||||
direct: true,
|
|
||||||
},
|
|
||||||
"route2": {
|
|
||||||
connected: true,
|
|
||||||
relayed: false,
|
|
||||||
direct: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
existingRoutes: map[string]*route.Route{
|
|
||||||
"route1": {
|
|
||||||
ID: "route1",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer1",
|
|
||||||
},
|
|
||||||
"route2": {
|
|
||||||
ID: "route2",
|
|
||||||
Metric: route.MaxMetric,
|
|
||||||
Peer: "peer2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
currentRoute: nil,
|
|
||||||
expectedRouteID: "route1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
// create new clientNetwork
|
|
||||||
client := &clientNetwork{
|
|
||||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
|
||||||
routes: tc.existingRoutes,
|
|
||||||
chosenRoute: tc.currentRoute,
|
|
||||||
}
|
|
||||||
|
|
||||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
|
||||||
if chosenRoute != tc.expectedRouteID {
|
|
||||||
t.Errorf("expected routeID %s, got %s", tc.expectedRouteID, chosenRoute)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,13 +30,33 @@ func genKey(format string, input string) string {
|
|||||||
|
|
||||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||||
func NewFirewall(parentCTX context.Context) firewallManager {
|
func NewFirewall(parentCTX context.Context) firewallManager {
|
||||||
manager, err := newNFTablesManager(parentCTX)
|
ctx, cancel := context.WithCancel(parentCTX)
|
||||||
if err == nil {
|
|
||||||
log.Debugf("nftables firewall manager will be used")
|
if isIptablesSupported() {
|
||||||
return manager
|
log.Debugf("iptables is supported")
|
||||||
|
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
|
||||||
|
return &iptablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
ipv4Client: ipv4Client,
|
||||||
|
ipv6Client: ipv6Client,
|
||||||
|
rules: make(map[string]map[string][]string),
|
||||||
}
|
}
|
||||||
log.Debugf("fallback to iptables firewall manager: %s", err)
|
}
|
||||||
return newIptablesManager(parentCTX)
|
|
||||||
|
log.Debugf("iptables is not supported, using nftables")
|
||||||
|
|
||||||
|
manager := &nftablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
chains: make(map[string]map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInPair(pair routerPair) routerPair {
|
func getInPair(pair routerPair) routerPair {
|
||||||
|
|||||||
@@ -49,28 +49,6 @@ type iptablesManager struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIptablesManager(parentCtx context.Context) *iptablesManager {
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
if !isIptablesClientAvailable(ipv4Client) {
|
|
||||||
log.Infof("iptables is missing for ipv4")
|
|
||||||
ipv4Client = nil
|
|
||||||
}
|
|
||||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
if !isIptablesClientAvailable(ipv6Client) {
|
|
||||||
log.Infof("iptables is missing for ipv6")
|
|
||||||
ipv6Client = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||||
func (i *iptablesManager) CleanRoutingRules() {
|
func (i *iptablesManager) CleanRoutingRules() {
|
||||||
i.mux.Lock()
|
i.mux.Lock()
|
||||||
@@ -83,7 +61,6 @@ func (i *iptablesManager) CleanRoutingRules() {
|
|||||||
|
|
||||||
log.Debug("flushing tables")
|
log.Debug("flushing tables")
|
||||||
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
||||||
if i.ipv4Client != nil {
|
|
||||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||||
@@ -93,9 +70,7 @@ func (i *iptablesManager) CleanRoutingRules() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if i.ipv6Client != nil {
|
|
||||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||||
@@ -105,7 +80,6 @@ func (i *iptablesManager) CleanRoutingRules() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("done cleaning up iptables rules")
|
log.Info("done cleaning up iptables rules")
|
||||||
}
|
}
|
||||||
@@ -122,7 +96,6 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|||||||
|
|
||||||
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
||||||
|
|
||||||
if i.ipv4Client != nil {
|
|
||||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||||
@@ -133,14 +106,7 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv4Client)
|
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if i.ipv6Client != nil {
|
|
||||||
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||||
}
|
}
|
||||||
@@ -150,13 +116,17 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = i.restoreRules(i.ipv4Client)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv6Client)
|
err = i.restoreRules(i.ipv6Client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
err := i.addJumpRules()
|
err = i.addJumpRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
||||||
}
|
}
|
||||||
@@ -170,13 +140,12 @@ func (i *iptablesManager) addJumpRules() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if i.ipv4Client != nil {
|
|
||||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||||
|
|
||||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
i.rules[ipv4][ipv4Forwarding] = rule
|
i.rules[ipv4][ipv4Forwarding] = rule
|
||||||
|
|
||||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||||
@@ -185,10 +154,8 @@ func (i *iptablesManager) addJumpRules() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
i.rules[ipv4][ipv4Nat] = rule
|
i.rules[ipv4][ipv4Nat] = rule
|
||||||
}
|
|
||||||
|
|
||||||
if i.ipv6Client != nil {
|
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||||
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
|
||||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -201,7 +168,6 @@ func (i *iptablesManager) addJumpRules() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
i.rules[ipv6][ipv6Nat] = rule
|
i.rules[ipv6][ipv6Nat] = rule
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -211,7 +177,6 @@ func (i *iptablesManager) cleanJumpRules() error {
|
|||||||
var err error
|
var err error
|
||||||
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
||||||
rule, found := i.rules[ipv4][ipv4Forwarding]
|
rule, found := i.rules[ipv4][ipv4Forwarding]
|
||||||
if i.ipv4Client != nil {
|
|
||||||
if found {
|
if found {
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||||
@@ -227,8 +192,6 @@ func (i *iptablesManager) cleanJumpRules() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if i.ipv6Client == nil {
|
|
||||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||||
if found {
|
if found {
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||||
@@ -245,7 +208,6 @@ func (i *iptablesManager) cleanJumpRules() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,8 +437,3 @@ func getIptablesRuleType(table string) string {
|
|||||||
}
|
}
|
||||||
return ruleType
|
return ruleType
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|
||||||
_, err := client.ListChains("filter")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -16,7 +16,17 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
manager := newIptablesManager(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
|
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
|
||||||
|
manager := &iptablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
ipv4Client: ipv4Client,
|
||||||
|
ipv6Client: ipv6Client,
|
||||||
|
rules: make(map[string]map[string][]string),
|
||||||
|
}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
@@ -27,21 +37,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
|
|
||||||
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
||||||
|
|
||||||
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
||||||
|
|
||||||
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
@@ -54,13 +64,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
||||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
nat4RuleKey := genKey(natFormat, pair.ID)
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
pair = routerPair{
|
pair = routerPair{
|
||||||
@@ -73,13 +83,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
||||||
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
nat6RuleKey := genKey(natFormat, pair.ID)
|
||||||
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
delete(manager.rules, ipv4)
|
delete(manager.rules, ipv4)
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ import (
|
|||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
SetRouteChangeListener(listener RouteListener)
|
|
||||||
InitialRouteRange() []string
|
|
||||||
Stop()
|
Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,14 +29,12 @@ type DefaultManager struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
pubKey string
|
pubKey string
|
||||||
notifier *notifier
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager returns a new route manager
|
// NewManager returns a new route manager
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
|
return &DefaultManager{
|
||||||
dm := &DefaultManager{
|
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
clientNetworks: make(map[string]*clientNetwork),
|
||||||
@@ -46,21 +42,13 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
pubKey: pubKey,
|
pubKey: pubKey,
|
||||||
notifier: newNotifier(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
|
||||||
cr := dm.clientRoutes(initialRoutes)
|
|
||||||
dm.notifier.setInitialClientRoutes(cr)
|
|
||||||
}
|
|
||||||
return dm
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the manager watchers and clean firewall rules
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
func (m *DefaultManager) Stop() {
|
func (m *DefaultManager) Stop() {
|
||||||
m.stop()
|
m.stop()
|
||||||
m.serverRouter.cleanUp()
|
m.serverRouter.cleanUp()
|
||||||
m.ctx = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||||
@@ -73,56 +61,6 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
|
|
||||||
|
|
||||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
|
||||||
m.notifier.onNewRoutes(newClientRoutesIDMap)
|
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRouteChangeListener set RouteListener for route change notifier
|
|
||||||
func (m *DefaultManager) SetRouteChangeListener(listener RouteListener) {
|
|
||||||
m.notifier.setListener(listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitialRouteRange return the list of initial routes. It used by mobile systems
|
|
||||||
func (m *DefaultManager) InitialRouteRange() []string {
|
|
||||||
return m.notifier.initialRouteRanges()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
|
||||||
// removing routes that do not exist as per the update from the Management service.
|
|
||||||
for id, client := range m.clientNetworks {
|
|
||||||
_, found := networks[id]
|
|
||||||
if !found {
|
|
||||||
log.Debugf("stopping client network watcher, %s", id)
|
|
||||||
client.stop()
|
|
||||||
delete(m.clientNetworks, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, routes := range networks {
|
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
|
||||||
if !found {
|
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
|
||||||
}
|
|
||||||
update := routesUpdate{
|
|
||||||
updateSerial: updateSerial,
|
|
||||||
routes: routes,
|
|
||||||
}
|
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
|
||||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||||
newServerRoutesMap := make(map[string]*route.Route)
|
newServerRoutesMap := make(map[string]*route.Route)
|
||||||
ownNetworkIDs := make(map[string]bool)
|
ownNetworkIDs := make(map[string]bool)
|
||||||
@@ -154,14 +92,39 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newServerRoutesMap, newClientRoutesIDMap
|
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||||
|
|
||||||
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||||
_, crMap := m.classifiesRoutes(initialRoutes)
|
// removing routes that do not exist as per the update from the Management service.
|
||||||
rs := make([]*route.Route, 0)
|
for id, client := range m.clientNetworks {
|
||||||
for _, routes := range crMap {
|
_, found := networks[id]
|
||||||
rs = append(rs, routes...)
|
if !found {
|
||||||
|
log.Debugf("stopping client network watcher, %s", id)
|
||||||
|
client.stop()
|
||||||
|
delete(m.clientNetworks, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, routes := range networks {
|
||||||
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
|
if !found {
|
||||||
|
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||||
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
|
}
|
||||||
|
update := routesUpdate{
|
||||||
|
updateSerial: updateSerial,
|
||||||
|
routes: routes,
|
||||||
|
}
|
||||||
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||||
}
|
}
|
||||||
return rs
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -397,7 +397,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet)
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
defer wgInterface.Close()
|
defer wgInterface.Close()
|
||||||
|
|
||||||
@@ -406,7 +406,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop()
|
||||||
|
|
||||||
if len(testCase.inputInitRoutes) > 0 {
|
if len(testCase.inputInitRoutes) > 0 {
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,11 +11,6 @@ type MockManager struct {
|
|||||||
StopFunc func()
|
StopFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
|
||||||
func (m *MockManager) InitialRouteRange() []string {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
if m.UpdateRoutesFunc != nil {
|
if m.UpdateRoutesFunc != nil {
|
||||||
@@ -27,15 +19,6 @@ func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route
|
|||||||
return fmt.Errorf("method UpdateRoutes is not implemented")
|
return fmt.Errorf("method UpdateRoutes is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start mock implementation of Start from Manager interface
|
|
||||||
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRouteChangeListener mock implementation of SetRouteChangeListener from Manager interface
|
|
||||||
func (m *MockManager) SetRouteChangeListener(listener RouteListener) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop mock implementation of Stop from Manager interface
|
// Stop mock implementation of Stop from Manager interface
|
||||||
func (m *MockManager) Stop() {
|
func (m *MockManager) Stop() {
|
||||||
if m.StopFunc != nil {
|
if m.StopFunc != nil {
|
||||||
|
|||||||
@@ -19,9 +19,6 @@ const (
|
|||||||
nftablesTable = "netbird-rt"
|
nftablesTable = "netbird-rt"
|
||||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||||
nftablesRoutingNatChain = "netbird-rt-nat"
|
nftablesRoutingNatChain = "netbird-rt-nat"
|
||||||
|
|
||||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
|
||||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to create nftable rules
|
// constants needed to create nftable rules
|
||||||
@@ -81,36 +78,9 @@ type nftablesManager struct {
|
|||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
chains map[string]map[string]*nftables.Chain
|
chains map[string]map[string]*nftables.Chain
|
||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
filterTable *nftables.Table
|
|
||||||
defaultForwardRules []*nftables.Rule
|
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
|
|
||||||
mgr := &nftablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
defaultForwardRules: make([]*nftables.Rule, 2),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := mgr.isSupported()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = mgr.readFilterTable()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return mgr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing nftables rules from the system
|
// CleanRoutingRules cleans existing nftables rules from the system
|
||||||
func (n *nftablesManager) CleanRoutingRules() {
|
func (n *nftablesManager) CleanRoutingRules() {
|
||||||
n.mux.Lock()
|
n.mux.Lock()
|
||||||
@@ -120,13 +90,6 @@ func (n *nftablesManager) CleanRoutingRules() {
|
|||||||
n.conn.FlushTable(n.tableIPv6)
|
n.conn.FlushTable(n.tableIPv6)
|
||||||
n.conn.FlushTable(n.tableIPv4)
|
n.conn.FlushTable(n.tableIPv4)
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.defaultForwardRules[0] != nil {
|
|
||||||
err := n.eraseDefaultForwardRule()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete forward rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,112 +222,6 @@ func (n *nftablesManager) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nftablesManager) readFilterTable() error {
|
|
||||||
tables, err := n.conn.ListTables()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == "filter" {
|
|
||||||
n.filterTable = t
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
|
||||||
if n.defaultForwardRules[0] == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := n.refreshDefaultForwardRule()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range n.defaultForwardRules {
|
|
||||||
err = n.conn.DelRule(r)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete forward rule (%d): %s", i, err)
|
|
||||||
}
|
|
||||||
n.defaultForwardRules[i] = nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nftablesManager) refreshDefaultForwardRule() error {
|
|
||||||
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to list rules in forward chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for i, r := range n.defaultForwardRules {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if string(rule.UserData) == string(r.UserData) {
|
|
||||||
n.defaultForwardRules[i] = rule
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("unable to find forward accept rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
|
|
||||||
src := generateCIDRMatcherExpressions("source", sourceNetwork)
|
|
||||||
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
|
|
||||||
|
|
||||||
var exprs []expr.Any
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
r := &nftables.Rule{
|
|
||||||
Table: n.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: n.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
|
||||||
}
|
|
||||||
|
|
||||||
n.defaultForwardRules[0] = n.conn.AddRule(r)
|
|
||||||
|
|
||||||
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
|
|
||||||
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
|
|
||||||
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
r = &nftables.Rule{
|
|
||||||
Table: n.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: n.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
|
||||||
}
|
|
||||||
|
|
||||||
n.defaultForwardRules[1] = n.conn.AddRule(r)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
||||||
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
||||||
_, foundIPv4 := n.rules[ipv4Forwarding]
|
_, foundIPv4 := n.rules[ipv4Forwarding]
|
||||||
@@ -418,14 +275,6 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
|
|
||||||
err = n.acceptForwardRule(pair.source)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to create default forward rule: %s", err)
|
|
||||||
}
|
|
||||||
log.Debugf("default accept forward rule added")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||||
@@ -506,13 +355,6 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
|
|
||||||
err := n.eraseDefaultForwardRule()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delte default fwd rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||||
@@ -544,14 +386,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nftablesManager) isSupported() error {
|
|
||||||
_, err := n.conn.ListChains()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables is not supported: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPayloadDirectives get expression directives based on ip version and direction
|
// getPayloadDirectives get expression directives based on ip version and direction
|
||||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@@ -14,16 +14,21 @@ import (
|
|||||||
|
|
||||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||||
|
|
||||||
manager, err := newNFTablesManager(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create nftables manager: %s", err)
|
manager := &nftablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
chains: make(map[string]map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
err := manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||||
@@ -129,16 +134,21 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range insertRuleTestCases {
|
for _, testCase := range insertRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
manager, err := newNFTablesManager(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create nftables manager: %s", err)
|
manager := &nftablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
chains: make(map[string]map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
err := manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||||
@@ -229,16 +239,21 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range removeRuleTestCases {
|
for _, testCase := range removeRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
manager, err := newNFTablesManager(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create nftables manager: %s", err)
|
manager := &nftablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
chains: make(map[string]map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
err := manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
table := manager.tableIPv4
|
table := manager.tableIPv4
|
||||||
|
|||||||
@@ -1,90 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RouteListener is a callback interface for mobile system
|
|
||||||
type RouteListener interface {
|
|
||||||
// OnNewRouteSetting invoke when new route setting has been arrived
|
|
||||||
OnNewRouteSetting()
|
|
||||||
}
|
|
||||||
|
|
||||||
type notifier struct {
|
|
||||||
initialRouteRangers []string
|
|
||||||
routeRangers []string
|
|
||||||
|
|
||||||
routeListener RouteListener
|
|
||||||
routeListenerMux sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNotifier() *notifier {
|
|
||||||
return ¬ifier{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) setListener(listener RouteListener) {
|
|
||||||
n.routeListenerMux.Lock()
|
|
||||||
defer n.routeListenerMux.Unlock()
|
|
||||||
n.routeListener = listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
|
|
||||||
nets := make([]string, 0)
|
|
||||||
for _, r := range clientRoutes {
|
|
||||||
nets = append(nets, r.Network.String())
|
|
||||||
}
|
|
||||||
sort.Strings(nets)
|
|
||||||
n.initialRouteRangers = nets
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
|
|
||||||
newNets := make([]string, 0)
|
|
||||||
for _, routes := range idMap {
|
|
||||||
for _, r := range routes {
|
|
||||||
newNets = append(newNets, r.Network.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(newNets)
|
|
||||||
if !n.hasDiff(n.routeRangers, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
n.routeRangers = newNets
|
|
||||||
|
|
||||||
if !n.hasDiff(n.initialRouteRangers, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n.notify()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) notify() {
|
|
||||||
n.routeListenerMux.Lock()
|
|
||||||
defer n.routeListenerMux.Unlock()
|
|
||||||
if n.routeListener == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(l RouteListener) {
|
|
||||||
l.OnNewRouteSetting()
|
|
||||||
}(n.routeListener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) hasDiff(a []string, b []string) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for i, v := range a {
|
|
||||||
if v != b[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) initialRouteRanges() []string {
|
|
||||||
return n.initialRouteRangers
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user