mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-04 23:19:55 +00:00
Compare commits
3 Commits
nb-interfa
...
feature/bu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43176fb96c | ||
|
|
d1e4bb0fa3 | ||
|
|
12b0c13511 |
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -37,21 +37,16 @@ If yes, which one?
|
|||||||
|
|
||||||
**Debug output**
|
**Debug output**
|
||||||
|
|
||||||
To help us resolve the problem, please attach the following anonymized status output
|
To help us resolve the problem, please attach the following debug output
|
||||||
|
|
||||||
netbird status -dA
|
netbird status -dA
|
||||||
|
|
||||||
Create and upload a debug bundle, and share the returned file key:
|
As well as the file created by
|
||||||
|
|
||||||
netbird debug for 1m -AS -U
|
|
||||||
|
|
||||||
*Uploaded files are automatically deleted after 30 days.*
|
|
||||||
|
|
||||||
|
|
||||||
Alternatively, create the file only and attach it here manually:
|
|
||||||
|
|
||||||
netbird debug for 1m -AS
|
netbird debug for 1m -AS
|
||||||
|
|
||||||
|
|
||||||
|
We advise reviewing the anonymized output for any remaining personal information.
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
@@ -62,10 +57,8 @@ If applicable, add screenshots to help explain your problem.
|
|||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|
||||||
**Have you tried these troubleshooting steps?**
|
**Have you tried these troubleshooting steps?**
|
||||||
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
|
||||||
- [ ] Checked for newer NetBird versions
|
- [ ] Checked for newer NetBird versions
|
||||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||||
- [ ] Restarted the NetBird client
|
- [ ] Restarted the NetBird client
|
||||||
- [ ] Disabled other VPN software
|
- [ ] Disabled other VPN software
|
||||||
- [ ] Checked firewall settings
|
- [ ] Checked firewall settings
|
||||||
|
|
||||||
|
|||||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -13,5 +13,3 @@
|
|||||||
- [ ] It is a refactor
|
- [ ] It is a refactor
|
||||||
- [ ] Created tests that fail without the change (if possible)
|
- [ ] Created tests that fail without the change (if possible)
|
||||||
- [ ] Extended the README / documentation, if necessary
|
- [ ] Extended the README / documentation, if necessary
|
||||||
|
|
||||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
|
||||||
|
|||||||
4
.github/workflows/git-town.yml
vendored
4
.github/workflows/git-town.yml
vendored
@@ -16,6 +16,6 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: git-town/action@v1.2.1
|
- uses: git-town/action@v1
|
||||||
with:
|
with:
|
||||||
skip-single-stacks: true
|
skip-single-stacks: true
|
||||||
8
.github/workflows/golang-test-linux.yml
vendored
8
.github/workflows/golang-test-linux.yml
vendored
@@ -223,10 +223,6 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -273,10 +269,6 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
|||||||
1
.github/workflows/golangci-lint.yml
vendored
1
.github/workflows/golangci-lint.yml
vendored
@@ -21,6 +21,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ jobs:
|
|||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build android netbird lib
|
- name: build android netbird lib
|
||||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||||
env:
|
env:
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||||
|
|||||||
23
.github/workflows/release.yml
vendored
23
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.21"
|
SIGN_PIPE_VER: "v0.0.18"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -65,13 +65,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Log in to the GitHub container registry
|
|
||||||
if: github.event_name != 'pull_request'
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
registry: ghcr.io
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
|
|
||||||
- name: Install OS build dependencies
|
- name: Install OS build dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||||
|
|
||||||
@@ -231,17 +224,3 @@ jobs:
|
|||||||
ref: ${{ env.SIGN_PIPE_VER }}
|
ref: ${{ env.SIGN_PIPE_VER }}
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||||
|
|
||||||
post_on_forum:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
continue-on-error: true
|
|
||||||
needs: [trigger_signer]
|
|
||||||
steps:
|
|
||||||
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
|
|
||||||
with:
|
|
||||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
|
||||||
discourse-base-url: https://forum.netbird.io
|
|
||||||
discourse-author-username: NetBird
|
|
||||||
discourse-category: 17
|
|
||||||
discourse-tags:
|
|
||||||
releases
|
|
||||||
|
|||||||
@@ -134,7 +134,6 @@ jobs:
|
|||||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
|
||||||
|
|
||||||
run: |
|
run: |
|
||||||
set -x
|
set -x
|
||||||
@@ -173,15 +172,13 @@ jobs:
|
|||||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||||
# check relay values
|
# check relay values
|
||||||
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||||
grep '33445:33445' docker-compose.yml
|
grep '33445:33445' docker-compose.yml
|
||||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||||
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
|
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||||
grep DisablePromptLogin management.json | grep 'true'
|
grep DisablePromptLogin management.json | grep 'true'
|
||||||
grep LoginFlag management.json | grep 0
|
|
||||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|||||||
136
.goreleaser.yaml
136
.goreleaser.yaml
@@ -149,7 +149,6 @@ nfpms:
|
|||||||
dockers:
|
dockers:
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -165,7 +164,6 @@ dockers:
|
|||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -177,11 +175,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-arm
|
- netbirdio/netbird:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -194,12 +191,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -211,11 +207,9 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -227,11 +221,9 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -244,12 +236,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-amd64
|
- netbirdio/relay:{{ .Version }}-amd64
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -261,11 +251,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -277,11 +266,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm
|
- netbirdio/relay:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -294,11 +282,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-amd64
|
- netbirdio/signal:{{ .Version }}-amd64
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -310,11 +297,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -326,11 +312,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-arm
|
- netbirdio/signal:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -343,11 +328,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-amd64
|
- netbirdio/management:{{ .Version }}-amd64
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -359,11 +343,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-arm64v8
|
- netbirdio/management:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -375,11 +358,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-arm
|
- netbirdio/management:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -392,11 +374,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -408,11 +389,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -424,12 +404,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm
|
- netbirdio/management:{{ .Version }}-debug-arm
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -442,11 +421,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/upload:{{ .Version }}-amd64
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-upload
|
- netbird-upload
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -458,11 +436,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-upload
|
- netbird-upload
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -474,11 +451,10 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/upload:{{ .Version }}-arm
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
|
||||||
ids:
|
ids:
|
||||||
- netbird-upload
|
- netbird-upload
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -491,7 +467,7 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
@@ -570,84 +546,6 @@ docker_manifests:
|
|||||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
- netbirdio/upload:{{ .Version }}-arm
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
- netbirdio/upload:{{ .Version }}-amd64
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/netbird:latest
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/relay:latest
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/signal:latest
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/management:latest
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/management:debug-latest
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
|
||||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/upload:latest
|
|
||||||
image_templates:
|
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
|
|||||||
16
README.md
16
README.md
@@ -12,11 +12,8 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://docs.netbird.io/slack-url">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
|
||||||
<a href="https://forum.netbird.io">
|
|
||||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://gurubase.io/g/netbird">
|
<a href="https://gurubase.io/g/netbird">
|
||||||
@@ -32,13 +29,13 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||||
New: NetBird terraform provider
|
New: NetBird Kubernetes Operator
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -50,9 +47,10 @@
|
|||||||
|
|
||||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||||
|
|
||||||
### Open Source Network Security in a Single Platform
|
### Open-Source Network Security in a Single Platform
|
||||||
|
|
||||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
|
||||||
|

|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### NetBird on Lawrence Systems (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||||
|
|||||||
7
buf.gen.yaml
Normal file
7
buf.gen.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# For details on buf.gen.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-gen-yaml/
|
||||||
|
version: v2
|
||||||
|
plugins:
|
||||||
|
- remote: buf.build/protocolbuffers/go:v1.35.1
|
||||||
|
out: .
|
||||||
|
- remote: buf.build/grpc/go:v1.5.1
|
||||||
|
out: .
|
||||||
10
buf.yaml
Normal file
10
buf.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# For details on buf.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-yaml
|
||||||
|
version: v2
|
||||||
|
modules:
|
||||||
|
- path: proto
|
||||||
|
lint:
|
||||||
|
use:
|
||||||
|
- BASIC
|
||||||
|
breaking:
|
||||||
|
use:
|
||||||
|
- FILE
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
FROM alpine:3.21.3
|
FROM alpine:3.21.3
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||||
|
|
||||||
ARG NETBIRD_BINARY=netbird
|
|
||||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
|
||||||
|
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||||
|
COPY netbird /usr/local/bin/netbird
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
FROM alpine:3.21.0
|
FROM alpine:3.21.0
|
||||||
|
|
||||||
ARG NETBIRD_BINARY=netbird
|
COPY netbird /usr/local/bin/netbird
|
||||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
|
||||||
|
|
||||||
RUN apk add --no-cache ca-certificates \
|
RUN apk add --no-cache ca-certificates \
|
||||||
&& adduser -D -h /var/lib/netbird netbird
|
&& adduser -D -h /var/lib/netbird netbird
|
||||||
|
|||||||
@@ -59,14 +59,10 @@ type Client struct {
|
|||||||
deviceName string
|
deviceName string
|
||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
connectClient *internal.ConnectClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||||
execWorkaround(androidSDKVersion)
|
|
||||||
|
|
||||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
@@ -110,8 +106,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -136,8 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -178,55 +174,6 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Networks() *NetworkArray {
|
|
||||||
if c.connectClient == nil {
|
|
||||||
log.Error("not connected")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
engine := c.connectClient.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
log.Error("could not get engine")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeManager := engine.GetRouteManager()
|
|
||||||
if routeManager == nil {
|
|
||||||
log.Error("could not get route manager")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
networkArray := &NetworkArray{
|
|
||||||
items: make([]Network, 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
|
||||||
if len(routes) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
r := routes[0]
|
|
||||||
netStr := r.Network.String()
|
|
||||||
if r.IsDynamic() {
|
|
||||||
netStr = r.Domains.SafeString()
|
|
||||||
}
|
|
||||||
|
|
||||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
network := Network{
|
|
||||||
Name: string(id),
|
|
||||||
Network: netStr,
|
|
||||||
Peer: peer.FQDN,
|
|
||||||
Status: peer.ConnStatus.String(),
|
|
||||||
}
|
|
||||||
networkArray.Add(network)
|
|
||||||
}
|
|
||||||
return networkArray
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
dnsServer, err := dns.GetServerDns()
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package android
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
_ "unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
|
|
||||||
// In Android version 11 and earlier, pidfd-related system calls
|
|
||||||
// are not allowed by the seccomp policy, which causes crashes due
|
|
||||||
// to SIGSYS signals.
|
|
||||||
|
|
||||||
//go:linkname checkPidfdOnce os.checkPidfdOnce
|
|
||||||
var checkPidfdOnce func() error
|
|
||||||
|
|
||||||
func execWorkaround(androidSDKVersion int) {
|
|
||||||
if androidSDKVersion > 30 { // above Android 11
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
checkPidfdOnce = func() error {
|
|
||||||
return fmt.Errorf("unsupported Android version")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package android
|
|
||||||
|
|
||||||
type Network struct {
|
|
||||||
Name string
|
|
||||||
Network string
|
|
||||||
Peer string
|
|
||||||
Status string
|
|
||||||
}
|
|
||||||
|
|
||||||
type NetworkArray struct {
|
|
||||||
items []Network
|
|
||||||
}
|
|
||||||
|
|
||||||
func (array *NetworkArray) Add(s Network) *NetworkArray {
|
|
||||||
array.items = append(array.items, s)
|
|
||||||
return array
|
|
||||||
}
|
|
||||||
|
|
||||||
func (array *NetworkArray) Get(i int) *Network {
|
|
||||||
return &array.items[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (array *NetworkArray) Size() int {
|
|
||||||
return len(array.items)
|
|
||||||
}
|
|
||||||
@@ -7,23 +7,30 @@ type PeerInfo struct {
|
|||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoArray is a wrapper of []PeerInfo
|
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||||
|
type PeerInfoCollection interface {
|
||||||
|
Add(s string) PeerInfoCollection
|
||||||
|
Get(i int) string
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerInfoArray is the implementation of the PeerInfoCollection
|
||||||
type PeerInfoArray struct {
|
type PeerInfoArray struct {
|
||||||
items []PeerInfo
|
items []PeerInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new PeerInfo to the collection
|
// Add new PeerInfo to the collection
|
||||||
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
|
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
||||||
array.items = append(array.items, s)
|
array.items = append(array.items, s)
|
||||||
return array
|
return array
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get return an element of the collection
|
// Get return an element of the collection
|
||||||
func (array *PeerInfoArray) Get(i int) *PeerInfo {
|
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
||||||
return &array.items[i]
|
return &array.items[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size return with the size of the collection
|
// Size return with the size of the collection
|
||||||
func (array *PeerInfoArray) Size() int {
|
func (array PeerInfoArray) Size() int {
|
||||||
return len(array.items)
|
return len(array.items)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Preferences exports a subset of the internal config for gomobile
|
// Preferences export a subset of the internal config for gomobile
|
||||||
type Preferences struct {
|
type Preferences struct {
|
||||||
configInput internal.ConfigInput
|
configInput internal.ConfigInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPreferences creates a new Preferences instance
|
// NewPreferences create new Preferences instance
|
||||||
func NewPreferences(configPath string) *Preferences {
|
func NewPreferences(configPath string) *Preferences {
|
||||||
ci := internal.ConfigInput{
|
ci := internal.ConfigInput{
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
|||||||
return &Preferences{ci}
|
return &Preferences{ci}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetManagementURL reads URL from config file
|
// GetManagementURL read url from config file
|
||||||
func (p *Preferences) GetManagementURL() (string, error) {
|
func (p *Preferences) GetManagementURL() (string, error) {
|
||||||
if p.configInput.ManagementURL != "" {
|
if p.configInput.ManagementURL != "" {
|
||||||
return p.configInput.ManagementURL, nil
|
return p.configInput.ManagementURL, nil
|
||||||
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
|||||||
return cfg.ManagementURL.String(), err
|
return cfg.ManagementURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetManagementURL stores the given URL and waits for commit
|
// SetManagementURL store the given url and wait for commit
|
||||||
func (p *Preferences) SetManagementURL(url string) {
|
func (p *Preferences) SetManagementURL(url string) {
|
||||||
p.configInput.ManagementURL = url
|
p.configInput.ManagementURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdminURL reads URL from config file
|
// GetAdminURL read url from config file
|
||||||
func (p *Preferences) GetAdminURL() (string, error) {
|
func (p *Preferences) GetAdminURL() (string, error) {
|
||||||
if p.configInput.AdminURL != "" {
|
if p.configInput.AdminURL != "" {
|
||||||
return p.configInput.AdminURL, nil
|
return p.configInput.AdminURL, nil
|
||||||
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
|||||||
return cfg.AdminURL.String(), err
|
return cfg.AdminURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAdminURL stores the given URL and waits for commit
|
// SetAdminURL store the given url and wait for commit
|
||||||
func (p *Preferences) SetAdminURL(url string) {
|
func (p *Preferences) SetAdminURL(url string) {
|
||||||
p.configInput.AdminURL = url
|
p.configInput.AdminURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreSharedKey reads pre-shared key from config file
|
// GetPreSharedKey read preshared key from config file
|
||||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||||
if p.configInput.PreSharedKey != nil {
|
if p.configInput.PreSharedKey != nil {
|
||||||
return *p.configInput.PreSharedKey, nil
|
return *p.configInput.PreSharedKey, nil
|
||||||
@@ -66,160 +66,12 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
|||||||
return cfg.PreSharedKey, err
|
return cfg.PreSharedKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPreSharedKey stores the given key and waits for commit
|
// SetPreSharedKey store the given key and wait for commit
|
||||||
func (p *Preferences) SetPreSharedKey(key string) {
|
func (p *Preferences) SetPreSharedKey(key string) {
|
||||||
p.configInput.PreSharedKey = &key
|
p.configInput.PreSharedKey = &key
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
// Commit write out the changes into config file
|
||||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
|
||||||
p.configInput.RosenpassEnabled = &enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
|
||||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
|
||||||
if p.configInput.RosenpassEnabled != nil {
|
|
||||||
return *p.configInput.RosenpassEnabled, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.RosenpassEnabled, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
|
||||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
|
||||||
p.configInput.RosenpassPermissive = &permissive
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
|
||||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
|
||||||
if p.configInput.RosenpassPermissive != nil {
|
|
||||||
return *p.configInput.RosenpassPermissive, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.RosenpassPermissive, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDisableClientRoutes reads disable client routes setting from config file
|
|
||||||
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
|
||||||
if p.configInput.DisableClientRoutes != nil {
|
|
||||||
return *p.configInput.DisableClientRoutes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.DisableClientRoutes, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisableClientRoutes stores the given value and waits for commit
|
|
||||||
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
|
||||||
p.configInput.DisableClientRoutes = &disable
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDisableServerRoutes reads disable server routes setting from config file
|
|
||||||
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
|
||||||
if p.configInput.DisableServerRoutes != nil {
|
|
||||||
return *p.configInput.DisableServerRoutes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.DisableServerRoutes, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisableServerRoutes stores the given value and waits for commit
|
|
||||||
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
|
||||||
p.configInput.DisableServerRoutes = &disable
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDisableDNS reads disable DNS setting from config file
|
|
||||||
func (p *Preferences) GetDisableDNS() (bool, error) {
|
|
||||||
if p.configInput.DisableDNS != nil {
|
|
||||||
return *p.configInput.DisableDNS, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.DisableDNS, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisableDNS stores the given value and waits for commit
|
|
||||||
func (p *Preferences) SetDisableDNS(disable bool) {
|
|
||||||
p.configInput.DisableDNS = &disable
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDisableFirewall reads disable firewall setting from config file
|
|
||||||
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
|
||||||
if p.configInput.DisableFirewall != nil {
|
|
||||||
return *p.configInput.DisableFirewall, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.DisableFirewall, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisableFirewall stores the given value and waits for commit
|
|
||||||
func (p *Preferences) SetDisableFirewall(disable bool) {
|
|
||||||
p.configInput.DisableFirewall = &disable
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
|
||||||
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
|
||||||
if p.configInput.ServerSSHAllowed != nil {
|
|
||||||
return *p.configInput.ServerSSHAllowed, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if cfg.ServerSSHAllowed == nil {
|
|
||||||
// Default to false for security on Android
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return *cfg.ServerSSHAllowed, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetServerSSHAllowed stores the given value and waits for commit
|
|
||||||
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
|
||||||
p.configInput.ServerSSHAllowed = &allowed
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBlockInbound reads block inbound setting from config file
|
|
||||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
|
||||||
if p.configInput.BlockInbound != nil {
|
|
||||||
return *p.configInput.BlockInbound, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return cfg.BlockInbound, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBlockInbound stores the given value and waits for commit
|
|
||||||
func (p *Preferences) SetBlockInbound(block bool) {
|
|
||||||
p.configInput.BlockInbound = &block
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit writes out the changes to the config file
|
|
||||||
func (p *Preferences) Commit() error {
|
func (p *Preferences) Commit() error {
|
||||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -69,22 +69,6 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
|||||||
return a.ipAnonymizer[ip]
|
return a.ipAnonymizer[ip]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
|
||||||
// Convert IP to netip.Addr
|
|
||||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
|
||||||
if !ok {
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
anonIP := a.AnonymizeIP(ip)
|
|
||||||
|
|
||||||
return net.UDPAddr{
|
|
||||||
IP: anonIP.AsSlice(),
|
|
||||||
Port: addr.Port,
|
|
||||||
Zone: addr.Zone,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||||
|
|||||||
@@ -17,18 +17,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
|
|
||||||
var (
|
|
||||||
logFileCount uint32
|
|
||||||
systemInfoFlag bool
|
|
||||||
uploadBundleFlag bool
|
|
||||||
uploadBundleURLFlag string
|
|
||||||
)
|
|
||||||
|
|
||||||
var debugCmd = &cobra.Command{
|
var debugCmd = &cobra.Command{
|
||||||
Use: "debug",
|
Use: "debug",
|
||||||
Short: "Debugging commands",
|
Short: "Debugging commands",
|
||||||
@@ -96,13 +88,12 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
|
||||||
}
|
}
|
||||||
if uploadBundleFlag {
|
if debugUploadBundle {
|
||||||
request.UploadURL = uploadBundleURLFlag
|
request.UploadURL = debugUploadBundleURL
|
||||||
}
|
}
|
||||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -114,7 +105,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||||
}
|
}
|
||||||
|
|
||||||
if uploadBundleFlag {
|
if debugUploadBundle {
|
||||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,13 +223,12 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: statusOutput,
|
Status: statusOutput,
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
|
||||||
}
|
}
|
||||||
if uploadBundleFlag {
|
if debugUploadBundle {
|
||||||
request.UploadURL = uploadBundleURLFlag
|
request.UploadURL = debugUploadBundleURL
|
||||||
}
|
}
|
||||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -265,7 +255,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||||
}
|
}
|
||||||
|
|
||||||
if uploadBundleFlag {
|
if debugUploadBundle {
|
||||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,7 +297,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
|||||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||||
} else {
|
} else {
|
||||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
|
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return statusOutputString
|
return statusOutputString
|
||||||
@@ -385,15 +375,3 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect
|
|||||||
}
|
}
|
||||||
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
|
||||||
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
|
|
||||||
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
|
||||||
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
|
||||||
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
|
||||||
|
|
||||||
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
|
|
||||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
|
||||||
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
|
||||||
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -99,11 +98,11 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
DnsLabels: dnsLabelsReq,
|
DnsLabels: dnsLabelsReq,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@@ -196,7 +195,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -244,10 +243,7 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isUnixRunningDesktop() bool {
|
func isLinuxRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,22 +22,26 @@ import (
|
|||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
externalIPMapFlag = "external-ip-map"
|
externalIPMapFlag = "external-ip-map"
|
||||||
dnsResolverAddress = "dns-resolver-address"
|
dnsResolverAddress = "dns-resolver-address"
|
||||||
enableRosenpassFlag = "enable-rosenpass"
|
enableRosenpassFlag = "enable-rosenpass"
|
||||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||||
preSharedKeyFlag = "preshared-key"
|
preSharedKeyFlag = "preshared-key"
|
||||||
interfaceNameFlag = "interface-name"
|
interfaceNameFlag = "interface-name"
|
||||||
wireguardPortFlag = "wireguard-port"
|
wireguardPortFlag = "wireguard-port"
|
||||||
networkMonitorFlag = "network-monitor"
|
networkMonitorFlag = "network-monitor"
|
||||||
disableAutoConnectFlag = "disable-auto-connect"
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
serverSSHAllowedFlag = "allow-server-ssh"
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
systemInfoFlag = "system-info"
|
||||||
|
blockLANAccessFlag = "block-lan-access"
|
||||||
|
uploadBundle = "upload-bundle"
|
||||||
|
uploadBundleURL = "upload-bundle-url"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -71,8 +75,11 @@ var (
|
|||||||
autoConnectDisabled bool
|
autoConnectDisabled bool
|
||||||
extraIFaceBlackList []string
|
extraIFaceBlackList []string
|
||||||
anonymizeFlag bool
|
anonymizeFlag bool
|
||||||
|
debugSystemInfoFlag bool
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
lazyConnEnabled bool
|
blockLANAccess bool
|
||||||
|
debugUploadBundle bool
|
||||||
|
debugUploadBundleURL string
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@@ -177,8 +184,10 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
|
||||||
|
|
||||||
|
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||||
|
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||||
|
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"runtime"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
@@ -28,19 +27,12 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newSVCConfig() *service.Config {
|
func newSVCConfig() *service.Config {
|
||||||
config := &service.Config{
|
return &service.Config{
|
||||||
Name: serviceName,
|
Name: serviceName,
|
||||||
DisplayName: "Netbird",
|
DisplayName: "Netbird",
|
||||||
Description: "Netbird mesh network client",
|
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||||
Option: make(service.KeyValue),
|
Option: make(service.KeyValue),
|
||||||
EnvVars: make(map[string]string),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
|
||||||
}
|
|
||||||
|
|
||||||
return config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
|
|||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if logFile != "" {
|
if logFile != "console" {
|
||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ var (
|
|||||||
statusFilter string
|
statusFilter string
|
||||||
ipsFilterMap map[string]struct{}
|
ipsFilterMap map[string]struct{}
|
||||||
prefixNamesFilterMap map[string]struct{}
|
prefixNamesFilterMap map[string]struct{}
|
||||||
connectionTypeFilter string
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var statusCmd = &cobra.Command{
|
var statusCmd = &cobra.Command{
|
||||||
@@ -45,8 +44,7 @@ func init() {
|
|||||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -71,10 +69,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
status := resp.GetStatus()
|
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||||
|
|
||||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
|
||||||
status == string(internal.StatusSessionExpired) {
|
|
||||||
cmd.Printf("Daemon status: %s\n\n"+
|
cmd.Printf("Daemon status: %s\n\n"+
|
||||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||||
" netbird up \n\n"+
|
" netbird up \n\n"+
|
||||||
@@ -91,7 +86,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
@@ -122,7 +117,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
@@ -132,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "idle", "connecting", "connected":
|
case "", "disconnected", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
if strings.ToLower(statusFilter) != "" {
|
||||||
enableDetailFlagWhenFilterFlag()
|
enableDetailFlagWhenFilterFlag()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ipsFilter) > 0 {
|
if len(ipsFilter) > 0 {
|
||||||
@@ -158,15 +153,6 @@ func parseFilters() error {
|
|||||||
enableDetailFlagWhenFilterFlag()
|
enableDetailFlagWhenFilterFlag()
|
||||||
}
|
}
|
||||||
|
|
||||||
switch strings.ToLower(connectionTypeFilter) {
|
|
||||||
case "", "p2p", "relayed":
|
|
||||||
if strings.ToLower(connectionTypeFilter) != "" {
|
|
||||||
enableDetailFlagWhenFilterFlag()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ const (
|
|||||||
disableServerRoutesFlag = "disable-server-routes"
|
disableServerRoutesFlag = "disable-server-routes"
|
||||||
disableDNSFlag = "disable-dns"
|
disableDNSFlag = "disable-dns"
|
||||||
disableFirewallFlag = "disable-firewall"
|
disableFirewallFlag = "disable-firewall"
|
||||||
blockLANAccessFlag = "block-lan-access"
|
|
||||||
blockInboundFlag = "block-inbound"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -15,8 +13,6 @@ var (
|
|||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
disableDNS bool
|
disableDNS bool
|
||||||
disableFirewall bool
|
disableFirewall bool
|
||||||
blockLANAccess bool
|
|
||||||
blockInbound bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -32,11 +28,4 @@ func init() {
|
|||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
|
||||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
|
||||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
|
||||||
"This overrides any policies received from the management service.")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,13 +103,13 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
|||||||
Example: `
|
Example: `
|
||||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
Args: cobra.ExactArgs(3),
|
Args: cobra.ExactArgs(3),
|
||||||
RunE: tracePacket,
|
RunE: tracePacket,
|
||||||
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
for _, stage := range resp.Stages {
|
for _, stage := range resp.Stages {
|
||||||
if stage.ForwardingDetails != nil {
|
if stage.ForwardingDetails != nil {
|
||||||
|
|||||||
280
client/cmd/up.go
280
client/cmd/up.go
@@ -55,11 +55,12 @@ func init() {
|
|||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||||
)
|
)
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||||
|
|
||||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||||
`Sets DNS labels`+
|
`Sets DNS labels`+
|
||||||
@@ -118,9 +119,79 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
ic := internal.ConfigInput{
|
||||||
if err != nil {
|
ManagementURL: managementURL,
|
||||||
return fmt.Errorf("setup config: %v", err)
|
AdminURL: adminURL,
|
||||||
|
ConfigPath: configPath,
|
||||||
|
NATExternalIPs: natExternalIPs,
|
||||||
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
|
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||||
|
DNSLabels: dnsLabelsValidated,
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||||
|
ic.RosenpassEnabled = &rosenpassEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||||
|
ic.RosenpassPermissive = &rosenpassPermissive
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
|
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ic.InterfaceName = &interfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(wireguardPortFlag).Changed {
|
||||||
|
p := int(wireguardPort)
|
||||||
|
ic.WireguardPort = &p
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(networkMonitorFlag).Changed {
|
||||||
|
ic.NetworkMonitor = &networkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
|
ic.PreSharedKey = &preSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
|
ic.DisableAutoConnect = &autoConnectDisabled
|
||||||
|
|
||||||
|
if autoConnectDisabled {
|
||||||
|
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !autoConnectDisabled {
|
||||||
|
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||||
|
ic.DNSRouteInterval = &dnsRouteInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||||
|
ic.DisableClientRoutes = &disableClientRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||||
|
ic.DisableServerRoutes = &disableServerRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableDNSFlag).Changed {
|
||||||
|
ic.DisableDNS = &disableDNS
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableFirewallFlag).Changed {
|
||||||
|
ic.DisableFirewall = &disableFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
ic.BlockLANAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
@@ -128,7 +199,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(*ic)
|
config, err := internal.UpdateOrCreateConfig(ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
@@ -187,153 +258,21 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get setup key: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("setup login request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var loginErr error
|
|
||||||
var loginResp *proto.LoginResponse
|
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
var backOffErr error
|
|
||||||
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
|
||||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
|
||||||
s.Code() == codes.PermissionDenied ||
|
|
||||||
s.Code() == codes.NotFound ||
|
|
||||||
s.Code() == codes.Unimplemented) {
|
|
||||||
loginErr = backOffErr
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return backOffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if loginErr != nil {
|
|
||||||
return fmt.Errorf("login failed: %v", loginErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
|
||||||
return fmt.Errorf("call service up method: %v", err)
|
|
||||||
}
|
|
||||||
cmd.Println("Connected")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
|
||||||
ic := internal.ConfigInput{
|
|
||||||
ManagementURL: managementURL,
|
|
||||||
AdminURL: adminURL,
|
|
||||||
ConfigPath: configPath,
|
|
||||||
NATExternalIPs: natExternalIPs,
|
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
|
||||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
|
||||||
DNSLabels: dnsLabelsValidated,
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
|
||||||
ic.RosenpassEnabled = &rosenpassEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
|
||||||
ic.RosenpassPermissive = &rosenpassPermissive
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
|
||||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ic.InterfaceName = &interfaceName
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(wireguardPortFlag).Changed {
|
|
||||||
p := int(wireguardPort)
|
|
||||||
ic.WireguardPort = &p
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(networkMonitorFlag).Changed {
|
|
||||||
ic.NetworkMonitor = &networkMonitor
|
|
||||||
}
|
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
|
||||||
ic.PreSharedKey = &preSharedKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
|
||||||
ic.DisableAutoConnect = &autoConnectDisabled
|
|
||||||
|
|
||||||
if autoConnectDisabled {
|
|
||||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !autoConnectDisabled {
|
|
||||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
|
||||||
ic.DNSRouteInterval = &dnsRouteInterval
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
|
||||||
ic.DisableClientRoutes = &disableClientRoutes
|
|
||||||
}
|
|
||||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
|
||||||
ic.DisableServerRoutes = &disableServerRoutes
|
|
||||||
}
|
|
||||||
if cmd.Flag(disableDNSFlag).Changed {
|
|
||||||
ic.DisableDNS = &disableDNS
|
|
||||||
}
|
|
||||||
if cmd.Flag(disableFirewallFlag).Changed {
|
|
||||||
ic.DisableFirewall = &disableFirewall
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
|
||||||
ic.BlockLANAccess = &blockLANAccess
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(blockInboundFlag).Changed {
|
|
||||||
ic.BlockInbound = &blockInbound
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
|
||||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
|
||||||
}
|
|
||||||
return &ic, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
NatExternalIPs: natExternalIPs,
|
NatExternalIPs: natExternalIPs,
|
||||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
DnsLabels: dnsLabels,
|
DnsLabels: dnsLabels,
|
||||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@@ -358,7 +297,7 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
loginRequest.InterfaceName = &interfaceName
|
loginRequest.InterfaceName = &interfaceName
|
||||||
}
|
}
|
||||||
@@ -393,14 +332,45 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.BlockLanAccess = &blockLANAccess
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(blockInboundFlag).Changed {
|
var loginErr error
|
||||||
loginRequest.BlockInbound = &blockInbound
|
|
||||||
|
var loginResp *proto.LoginResponse
|
||||||
|
|
||||||
|
err = WithBackOff(func() error {
|
||||||
|
var backOffErr error
|
||||||
|
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||||
|
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||||
|
s.Code() == codes.PermissionDenied ||
|
||||||
|
s.Code() == codes.NotFound ||
|
||||||
|
s.Code() == codes.Unimplemented) {
|
||||||
|
loginErr = backOffErr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return backOffErr
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
if loginErr != nil {
|
||||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
return fmt.Errorf("login failed: %v", loginErr)
|
||||||
}
|
}
|
||||||
return &loginRequest, nil
|
|
||||||
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||||
|
return fmt.Errorf("call service up method: %v", err)
|
||||||
|
}
|
||||||
|
cmd.Println("Connected")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNATExternalIPs(list []string) error {
|
func validateNATExternalIPs(list []string) error {
|
||||||
|
|||||||
@@ -147,10 +147,6 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsStateful() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -202,7 +198,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
firewall.ProtocolALL,
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
@@ -223,16 +219,10 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package iptables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,8 +19,11 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -67,12 +70,12 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -92,9 +95,9 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -116,8 +119,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -138,11 +144,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -180,8 +186,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -203,11 +212,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip := netip.MustParseAddr("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -248,6 +248,10 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
|
|
||||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if r.legacyManagement {
|
if r.legacyManagement {
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
@@ -274,6 +278,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
|
|||||||
@@ -116,8 +116,6 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
IsStateful() bool
|
|
||||||
|
|
||||||
AddRouteFiltering(
|
AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
|
|||||||
@@ -170,10 +170,6 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsStateful() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -328,16 +324,10 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -24,8 +25,11 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -66,11 +70,11 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
ip := net.ParseIP("100.96.0.1")
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -105,6 +109,8 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
|
add := ipToAdd.Unmap()
|
||||||
expectedExprs2 := []expr.Any{
|
expectedExprs2 := []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -126,7 +132,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ip.AsSlice(),
|
Data: add.AsSlice(),
|
||||||
},
|
},
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -167,8 +173,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -188,11 +197,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := netip.MustParseAddr("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -273,8 +282,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
})
|
})
|
||||||
|
|
||||||
ip := netip.MustParseAddr("100.96.0.1")
|
ip := net.ParseIP("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
|
|||||||
@@ -573,6 +573,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
|||||||
|
|
||||||
// AddNatRule appends a nftables rule pair to the nat chain
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -1002,6 +1006,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
|
|
||||||
// RemoveNatRule removes the prerouting mark rule
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,5 +62,5 @@ type ConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c ConnKey) String() string {
|
func (c ConnKey) String() string {
|
||||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,10 +19,6 @@ const (
|
|||||||
DefaultICMPTimeout = 30 * time.Second
|
DefaultICMPTimeout = 30 * time.Second
|
||||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
ICMPCleanupInterval = 15 * time.Second
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
|
||||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
|
||||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
|
||||||
MaxICMPPayloadLength = 28
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
@@ -34,7 +29,7 @@ type ICMPConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i ICMPConnKey) String() string {
|
func (i ICMPConnKey) String() string {
|
||||||
return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
@@ -55,72 +50,6 @@ type ICMPTracker struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
|
|
||||||
type ICMPInfo struct {
|
|
||||||
TypeCode layers.ICMPv4TypeCode
|
|
||||||
PayloadData [MaxICMPPayloadLength]byte
|
|
||||||
// actual length of valid data
|
|
||||||
PayloadLen int
|
|
||||||
}
|
|
||||||
|
|
||||||
// String implements fmt.Stringer for lazy evaluation in log messages
|
|
||||||
func (info ICMPInfo) String() string {
|
|
||||||
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
|
||||||
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
|
||||||
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return info.TypeCode.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
|
||||||
func (info ICMPInfo) isErrorMessage() bool {
|
|
||||||
typ := info.TypeCode.Type()
|
|
||||||
return typ == 3 || // Destination Unreachable
|
|
||||||
typ == 5 || // Redirect
|
|
||||||
typ == 11 || // Time Exceeded
|
|
||||||
typ == 12 // Parameter Problem
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
|
||||||
func (info ICMPInfo) parseOriginalPacket() string {
|
|
||||||
if info.PayloadLen < MaxICMPPayloadLength {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: handle IPv6
|
|
||||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
protocol := info.PayloadData[9]
|
|
||||||
srcIP := net.IP(info.PayloadData[12:16])
|
|
||||||
dstIP := net.IP(info.PayloadData[16:20])
|
|
||||||
|
|
||||||
transportData := info.PayloadData[20:]
|
|
||||||
|
|
||||||
switch nftypes.Protocol(protocol) {
|
|
||||||
case nftypes.TCP:
|
|
||||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
|
||||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
|
||||||
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
|
||||||
|
|
||||||
case nftypes.UDP:
|
|
||||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
|
||||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
|
||||||
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
|
||||||
|
|
||||||
case nftypes.ICMP:
|
|
||||||
icmpType := transportData[0]
|
|
||||||
icmpCode := transportData[1]
|
|
||||||
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
@@ -164,64 +93,30 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP connection
|
// TrackOutbound records an outbound ICMP connection
|
||||||
func (t *ICMPTracker) TrackOutbound(
|
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||||
srcIP netip.Addr,
|
|
||||||
dstIP netip.Addr,
|
|
||||||
id uint16,
|
|
||||||
typecode layers.ICMPv4TypeCode,
|
|
||||||
payload []byte,
|
|
||||||
size int,
|
|
||||||
) {
|
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound records an inbound ICMP Echo Request
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackInbound(
|
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||||
srcIP netip.Addr,
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||||
dstIP netip.Addr,
|
|
||||||
id uint16,
|
|
||||||
typecode layers.ICMPv4TypeCode,
|
|
||||||
ruleId []byte,
|
|
||||||
payload []byte,
|
|
||||||
size int,
|
|
||||||
) {
|
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
func (t *ICMPTracker) track(
|
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||||
srcIP netip.Addr,
|
|
||||||
dstIP netip.Addr,
|
|
||||||
id uint16,
|
|
||||||
typecode layers.ICMPv4TypeCode,
|
|
||||||
direction nftypes.Direction,
|
|
||||||
ruleId []byte,
|
|
||||||
payload []byte,
|
|
||||||
size int,
|
|
||||||
) {
|
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
if exists {
|
if exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
typ, code := typecode.Type(), typecode.Code()
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
icmpInfo := ICMPInfo{
|
|
||||||
TypeCode: typecode,
|
|
||||||
}
|
|
||||||
if len(payload) > 0 {
|
|
||||||
icmpInfo.PayloadLen = len(payload)
|
|
||||||
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
|
|
||||||
icmpInfo.PayloadLen = MaxICMPPayloadLength
|
|
||||||
}
|
|
||||||
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
// non echo requests don't need tracking
|
// non echo requests don't need tracking
|
||||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -243,7 +138,7 @@ func (t *ICMPTracker) track(
|
|||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|||||||
@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
|
|||||||
|
|
||||||
func (i epID) String() string {
|
func (i epID) String() string {
|
||||||
// src and remote is swapped
|
// src and remote is swapped
|
||||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type Forwarder struct {
|
|||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ip tcpip.Address
|
ip net.IP
|
||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,11 +71,12 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ones, _ := iface.Address().Network.Mask.Size()
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||||
PrefixLen: iface.Address().Network.Bits(),
|
PrefixLen: ones,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +116,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
ip: iface.Address().IP,
|
||||||
}
|
}
|
||||||
|
|
||||||
receiveWindow := defaultReceiveWindow
|
receiveWindow := defaultReceiveWindow
|
||||||
@@ -166,7 +167,7 @@ func (f *Forwarder) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
if f.netstack && f.ip.Equal(addr) {
|
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||||
return net.IPv4(127, 0, 0, 1)
|
return net.IPv4(127, 0, 0, 1)
|
||||||
}
|
}
|
||||||
return addr.AsSlice()
|
return addr.AsSlice()
|
||||||
@@ -178,6 +179,7 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
|
|
||||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||||
return value.([]byte), true
|
return value.([]byte), true
|
||||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||||
|
|||||||
@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
|
|
||||||
if errInToOut != nil {
|
if errInToOut != nil {
|
||||||
if !isClosedError(errInToOut) {
|
if !isClosedError(errInToOut) {
|
||||||
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errOutToIn != nil {
|
if errOutToIn != nil {
|
||||||
if !isClosedError(errOutToIn) {
|
if !isClosedError(errOutToIn) {
|
||||||
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||||
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
|
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
|
||||||
}
|
}
|
||||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||||
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
|
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rxPackets, txPackets uint64
|
var rxPackets, txPackets uint64
|
||||||
|
|||||||
@@ -45,26 +45,24 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
if !ip.Is4() {
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
return
|
high := uint16(ipv4[0])
|
||||||
}
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
ipv4 := ip.AsSlice()
|
|
||||||
|
|
||||||
high := uint16(ipv4[0])
|
if bitmap[high] == nil {
|
||||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
bitmap[high] = &ipv4LowBitmap{}
|
||||||
|
}
|
||||||
|
|
||||||
if bitmap[high] == nil {
|
index := low / 32
|
||||||
bitmap[high] = &ipv4LowBitmap{}
|
bit := low % 32
|
||||||
}
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
|
||||||
index := low / 32
|
ipStr := ipv4.String()
|
||||||
bit := low % 32
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
bitmap[high].bitmap[index] |= 1 << bit
|
ipv4Set[ipStr] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
if _, exists := ipv4Set[ip]; !exists {
|
}
|
||||||
ipv4Set[ip] = struct{}{}
|
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,12 +79,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
|||||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -104,13 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(ip)
|
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
if !ok {
|
|
||||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
|
||||||
log.Debugf("process IP failed: %v", err)
|
log.Debugf("process IP failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -124,8 +116,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
ipv4Set := make(map[netip.Addr]struct{})
|
ipv4Set := make(map[string]struct{})
|
||||||
var ipv4Addresses []netip.Addr
|
var ipv4Addresses []string
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// 127.0.0.0/8
|
||||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
|
|||||||
@@ -20,8 +20,11 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range",
|
name: "Localhost range",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -29,8 +32,11 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost standard address",
|
name: "Localhost standard address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -38,8 +44,11 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range edge",
|
name: "Localhost range edge",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -47,8 +56,11 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP matches",
|
name: "Local IP matches",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -56,8 +68,11 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match",
|
name: "Local IP doesn't match",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -65,8 +80,11 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match - addresses 32 apart",
|
name: "Local IP doesn't match - addresses 32 apart",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -74,8 +92,11 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("fe80::1"),
|
IP: net.ParseIP("fe80::1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("fe80::"),
|
||||||
|
Mask: net.CIDRMask(64, 128),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
|
|||||||
@@ -1,408 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
|
||||||
|
|
||||||
func ipv4Checksum(header []byte) uint16 {
|
|
||||||
if len(header) < 20 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
var sum1, sum2 uint32
|
|
||||||
|
|
||||||
// Parallel processing - unroll and compute two sums simultaneously
|
|
||||||
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
|
||||||
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
|
||||||
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
|
||||||
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
|
||||||
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
|
||||||
// Skip checksum field at [10:12]
|
|
||||||
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
|
||||||
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
|
||||||
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
|
||||||
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
|
|
||||||
|
|
||||||
sum := sum1 + sum2
|
|
||||||
|
|
||||||
// Handle remaining bytes for headers > 20 bytes
|
|
||||||
for i := 20; i < len(header)-1; i += 2 {
|
|
||||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(header)%2 == 1 {
|
|
||||||
sum += uint32(header[len(header)-1]) << 8
|
|
||||||
}
|
|
||||||
|
|
||||||
// Optimized carry fold - single iteration handles most cases
|
|
||||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
|
||||||
if sum > 0xFFFF {
|
|
||||||
sum++
|
|
||||||
}
|
|
||||||
|
|
||||||
return ^uint16(sum)
|
|
||||||
}
|
|
||||||
|
|
||||||
func icmpChecksum(data []byte) uint16 {
|
|
||||||
var sum1, sum2, sum3, sum4 uint32
|
|
||||||
i := 0
|
|
||||||
|
|
||||||
// Process 16 bytes at once with 4 parallel accumulators
|
|
||||||
for i <= len(data)-16 {
|
|
||||||
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
|
||||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
|
||||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
|
|
||||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
|
|
||||||
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
|
|
||||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
|
|
||||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
|
|
||||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
|
|
||||||
i += 16
|
|
||||||
}
|
|
||||||
|
|
||||||
sum := sum1 + sum2 + sum3 + sum4
|
|
||||||
|
|
||||||
// Handle remaining bytes
|
|
||||||
for i < len(data)-1 {
|
|
||||||
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
|
||||||
i += 2
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(data)%2 == 1 {
|
|
||||||
sum += uint32(data[len(data)-1]) << 8
|
|
||||||
}
|
|
||||||
|
|
||||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
|
||||||
if sum > 0xFFFF {
|
|
||||||
sum++
|
|
||||||
}
|
|
||||||
|
|
||||||
return ^uint16(sum)
|
|
||||||
}
|
|
||||||
|
|
||||||
type biDNATMap struct {
|
|
||||||
forward map[netip.Addr]netip.Addr
|
|
||||||
reverse map[netip.Addr]netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBiDNATMap() *biDNATMap {
|
|
||||||
return &biDNATMap{
|
|
||||||
forward: make(map[netip.Addr]netip.Addr),
|
|
||||||
reverse: make(map[netip.Addr]netip.Addr),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
|
||||||
b.forward[original] = translated
|
|
||||||
b.reverse[translated] = original
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *biDNATMap) delete(original netip.Addr) {
|
|
||||||
if translated, exists := b.forward[original]; exists {
|
|
||||||
delete(b.forward, original)
|
|
||||||
delete(b.reverse, translated)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
|
||||||
translated, exists := b.forward[original]
|
|
||||||
return translated, exists
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
|
||||||
original, exists := b.reverse[translated]
|
|
||||||
return original, exists
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
|
||||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
|
||||||
return fmt.Errorf("invalid IP addresses")
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
|
||||||
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.dnatMutex.Lock()
|
|
||||||
defer m.dnatMutex.Unlock()
|
|
||||||
|
|
||||||
// Initialize both maps together if either is nil
|
|
||||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
|
||||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
|
||||||
m.dnatBiMap = newBiDNATMap()
|
|
||||||
}
|
|
||||||
|
|
||||||
m.dnatMappings[originalAddr] = translatedAddr
|
|
||||||
m.dnatBiMap.set(originalAddr, translatedAddr)
|
|
||||||
|
|
||||||
if len(m.dnatMappings) == 1 {
|
|
||||||
m.dnatEnabled.Store(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
|
||||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
|
||||||
m.dnatMutex.Lock()
|
|
||||||
defer m.dnatMutex.Unlock()
|
|
||||||
|
|
||||||
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
|
||||||
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(m.dnatMappings, originalAddr)
|
|
||||||
m.dnatBiMap.delete(originalAddr)
|
|
||||||
if len(m.dnatMappings) == 0 {
|
|
||||||
m.dnatEnabled.Store(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getDNATTranslation returns the translated address if a mapping exists
|
|
||||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
|
||||||
if !m.dnatEnabled.Load() {
|
|
||||||
return addr, false
|
|
||||||
}
|
|
||||||
|
|
||||||
m.dnatMutex.RLock()
|
|
||||||
translated, exists := m.dnatBiMap.getTranslated(addr)
|
|
||||||
m.dnatMutex.RUnlock()
|
|
||||||
return translated, exists
|
|
||||||
}
|
|
||||||
|
|
||||||
// findReverseDNATMapping finds original address for return traffic
|
|
||||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
|
||||||
if !m.dnatEnabled.Load() {
|
|
||||||
return translatedAddr, false
|
|
||||||
}
|
|
||||||
|
|
||||||
m.dnatMutex.RLock()
|
|
||||||
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
|
|
||||||
m.dnatMutex.RUnlock()
|
|
||||||
return original, exists
|
|
||||||
}
|
|
||||||
|
|
||||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
|
||||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
|
||||||
if !m.dnatEnabled.Load() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
|
||||||
|
|
||||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
|
||||||
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
|
||||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
|
||||||
if !m.dnatEnabled.Load() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
|
||||||
|
|
||||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
|
||||||
m.logger.Error("Failed to rewrite packet source: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewritePacketDestination replaces destination IP in the packet
|
|
||||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
|
||||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
|
||||||
return ErrIPv4Only
|
|
||||||
}
|
|
||||||
|
|
||||||
var oldDst [4]byte
|
|
||||||
copy(oldDst[:], packetData[16:20])
|
|
||||||
newDst := newIP.As4()
|
|
||||||
|
|
||||||
copy(packetData[16:20], newDst[:])
|
|
||||||
|
|
||||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
|
||||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
|
||||||
return fmt.Errorf("invalid IP header length")
|
|
||||||
}
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
|
||||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
|
||||||
|
|
||||||
if len(d.decoded) > 1 {
|
|
||||||
switch d.decoded[1] {
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
|
||||||
case layers.LayerTypeICMPv4:
|
|
||||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewritePacketSource replaces the source IP address in the packet
|
|
||||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
|
||||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
|
||||||
return ErrIPv4Only
|
|
||||||
}
|
|
||||||
|
|
||||||
var oldSrc [4]byte
|
|
||||||
copy(oldSrc[:], packetData[12:16])
|
|
||||||
newSrc := newIP.As4()
|
|
||||||
|
|
||||||
copy(packetData[12:16], newSrc[:])
|
|
||||||
|
|
||||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
|
||||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
|
||||||
return fmt.Errorf("invalid IP header length")
|
|
||||||
}
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
|
||||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
|
||||||
|
|
||||||
if len(d.decoded) > 1 {
|
|
||||||
switch d.decoded[1] {
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
|
||||||
case layers.LayerTypeICMPv4:
|
|
||||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
|
||||||
tcpStart := ipHeaderLen
|
|
||||||
if len(packetData) < tcpStart+18 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
checksumOffset := tcpStart + 16
|
|
||||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
|
||||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
|
||||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
|
||||||
udpStart := ipHeaderLen
|
|
||||||
if len(packetData) < udpStart+8 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
checksumOffset := udpStart + 6
|
|
||||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
|
||||||
|
|
||||||
if oldChecksum == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
|
||||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
|
||||||
icmpStart := ipHeaderLen
|
|
||||||
if len(packetData) < icmpStart+8 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
icmpData := packetData[icmpStart:]
|
|
||||||
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
|
||||||
checksum := icmpChecksum(icmpData)
|
|
||||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
|
||||||
}
|
|
||||||
|
|
||||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
|
||||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
|
||||||
sum := uint32(^oldChecksum)
|
|
||||||
|
|
||||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
|
||||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
|
||||||
} else {
|
|
||||||
// Fallback for other lengths
|
|
||||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
|
||||||
}
|
|
||||||
if len(oldBytes)%2 == 1 {
|
|
||||||
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(newBytes)-1; i += 2 {
|
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
|
||||||
}
|
|
||||||
if len(newBytes)%2 == 1 {
|
|
||||||
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
|
||||||
if sum > 0xFFFF {
|
|
||||||
sum++
|
|
||||||
}
|
|
||||||
|
|
||||||
return ^uint16(sum)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
|
||||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return nil, errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
|
||||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
|
||||||
}
|
|
||||||
@@ -1,416 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BenchmarkDNATTranslation measures the performance of DNAT operations
|
|
||||||
func BenchmarkDNATTranslation(b *testing.B) {
|
|
||||||
scenarios := []struct {
|
|
||||||
name string
|
|
||||||
proto layers.IPProtocol
|
|
||||||
setupDNAT bool
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "tcp_with_dnat",
|
|
||||||
proto: layers.IPProtocolTCP,
|
|
||||||
setupDNAT: true,
|
|
||||||
description: "TCP packet with DNAT translation enabled",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tcp_without_dnat",
|
|
||||||
proto: layers.IPProtocolTCP,
|
|
||||||
setupDNAT: false,
|
|
||||||
description: "TCP packet without DNAT (baseline)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "udp_with_dnat",
|
|
||||||
proto: layers.IPProtocolUDP,
|
|
||||||
setupDNAT: true,
|
|
||||||
description: "UDP packet with DNAT translation enabled",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "udp_without_dnat",
|
|
||||||
proto: layers.IPProtocolUDP,
|
|
||||||
setupDNAT: false,
|
|
||||||
description: "UDP packet without DNAT (baseline)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "icmp_with_dnat",
|
|
||||||
proto: layers.IPProtocolICMPv4,
|
|
||||||
setupDNAT: true,
|
|
||||||
description: "ICMP packet with DNAT translation enabled",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "icmp_without_dnat",
|
|
||||||
proto: layers.IPProtocolICMPv4,
|
|
||||||
setupDNAT: false,
|
|
||||||
description: "ICMP packet without DNAT (baseline)",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sc := range scenarios {
|
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
|
||||||
manager, err := Create(&IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}, false, flowLogger)
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(b, manager.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Set logger to error level to reduce noise during benchmarking
|
|
||||||
manager.SetLogLevel(log.ErrorLevel)
|
|
||||||
defer func() {
|
|
||||||
// Restore to info level after benchmark
|
|
||||||
manager.SetLogLevel(log.InfoLevel)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Setup DNAT mapping if needed
|
|
||||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
|
||||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
|
||||||
|
|
||||||
if sc.setupDNAT {
|
|
||||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
|
||||||
require.NoError(b, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create test packets
|
|
||||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
|
||||||
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
|
||||||
|
|
||||||
// Pre-establish connection for reverse DNAT test
|
|
||||||
if sc.setupDNAT {
|
|
||||||
manager.filterOutbound(outboundPacket, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
// Benchmark outbound DNAT translation
|
|
||||||
b.Run("outbound", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// Create fresh packet each time since translation modifies it
|
|
||||||
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
|
||||||
manager.filterOutbound(packet, 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Benchmark inbound reverse DNAT translation
|
|
||||||
if sc.setupDNAT {
|
|
||||||
b.Run("inbound_reverse", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// Create fresh packet each time since translation modifies it
|
|
||||||
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
|
|
||||||
manager.filterInbound(packet, 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
|
||||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
|
||||||
manager, err := Create(&IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}, false, flowLogger)
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(b, manager.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Set logger to error level to reduce noise during benchmarking
|
|
||||||
manager.SetLogLevel(log.ErrorLevel)
|
|
||||||
defer func() {
|
|
||||||
// Restore to info level after benchmark
|
|
||||||
manager.SetLogLevel(log.InfoLevel)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Setup multiple DNAT mappings
|
|
||||||
numMappings := 100
|
|
||||||
originalIPs := make([]netip.Addr, numMappings)
|
|
||||||
translatedIPs := make([]netip.Addr, numMappings)
|
|
||||||
|
|
||||||
for i := 0; i < numMappings; i++ {
|
|
||||||
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
|
||||||
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
|
||||||
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
|
|
||||||
require.NoError(b, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
|
||||||
|
|
||||||
// Pre-generate packets
|
|
||||||
outboundPackets := make([][]byte, numMappings)
|
|
||||||
inboundPackets := make([][]byte, numMappings)
|
|
||||||
for i := 0; i < numMappings; i++ {
|
|
||||||
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
|
|
||||||
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
|
|
||||||
// Establish connections
|
|
||||||
manager.filterOutbound(outboundPackets[i], 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
b.Run("concurrent_outbound", func(b *testing.B) {
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
|
||||||
i := 0
|
|
||||||
for pb.Next() {
|
|
||||||
idx := i % numMappings
|
|
||||||
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
|
|
||||||
manager.filterOutbound(packet, 0)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("concurrent_inbound", func(b *testing.B) {
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
|
||||||
i := 0
|
|
||||||
for pb.Next() {
|
|
||||||
idx := i % numMappings
|
|
||||||
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
|
|
||||||
manager.filterInbound(packet, 0)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
|
|
||||||
func BenchmarkDNATScaling(b *testing.B) {
|
|
||||||
mappingCounts := []int{1, 10, 100, 1000}
|
|
||||||
|
|
||||||
for _, count := range mappingCounts {
|
|
||||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
|
||||||
manager, err := Create(&IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}, false, flowLogger)
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(b, manager.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Set logger to error level to reduce noise during benchmarking
|
|
||||||
manager.SetLogLevel(log.ErrorLevel)
|
|
||||||
defer func() {
|
|
||||||
// Restore to info level after benchmark
|
|
||||||
manager.SetLogLevel(log.InfoLevel)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Setup DNAT mappings
|
|
||||||
for i := 0; i < count; i++ {
|
|
||||||
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
|
||||||
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
|
||||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
|
||||||
require.NoError(b, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with the last mapping added (worst case for lookup)
|
|
||||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
|
||||||
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
|
|
||||||
manager.filterOutbound(packet, 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDNATTestPacket creates a test packet for DNAT benchmarking
|
|
||||||
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
|
|
||||||
tb.Helper()
|
|
||||||
|
|
||||||
ipv4 := &layers.IPv4{
|
|
||||||
TTL: 64,
|
|
||||||
Version: 4,
|
|
||||||
SrcIP: srcIP.AsSlice(),
|
|
||||||
DstIP: dstIP.AsSlice(),
|
|
||||||
Protocol: proto,
|
|
||||||
}
|
|
||||||
|
|
||||||
var transportLayer gopacket.SerializableLayer
|
|
||||||
switch proto {
|
|
||||||
case layers.IPProtocolTCP:
|
|
||||||
tcp := &layers.TCP{
|
|
||||||
SrcPort: layers.TCPPort(srcPort),
|
|
||||||
DstPort: layers.TCPPort(dstPort),
|
|
||||||
SYN: true,
|
|
||||||
}
|
|
||||||
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
|
|
||||||
transportLayer = tcp
|
|
||||||
case layers.IPProtocolUDP:
|
|
||||||
udp := &layers.UDP{
|
|
||||||
SrcPort: layers.UDPPort(srcPort),
|
|
||||||
DstPort: layers.UDPPort(dstPort),
|
|
||||||
}
|
|
||||||
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
|
|
||||||
transportLayer = udp
|
|
||||||
case layers.IPProtocolICMPv4:
|
|
||||||
icmp := &layers.ICMPv4{
|
|
||||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
|
||||||
}
|
|
||||||
transportLayer = icmp
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
|
||||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
|
||||||
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
|
||||||
require.NoError(tb, err)
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
|
|
||||||
func BenchmarkChecksumUpdate(b *testing.B) {
|
|
||||||
// Create test data for checksum calculations
|
|
||||||
testData := make([]byte, 64) // Typical packet size for checksum testing
|
|
||||||
for i := range testData {
|
|
||||||
testData[i] = byte(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.Run("ipv4_checksum", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("icmp_checksum", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = icmpChecksum(testData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("incremental_update", func(b *testing.B) {
|
|
||||||
oldBytes := []byte{192, 168, 1, 100}
|
|
||||||
newBytes := []byte{10, 0, 0, 100}
|
|
||||||
oldChecksum := uint16(0x1234)
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
|
||||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
|
||||||
manager, err := Create(&IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}, false, flowLogger)
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(b, manager.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Set logger to error level to reduce noise during benchmarking
|
|
||||||
manager.SetLogLevel(log.ErrorLevel)
|
|
||||||
defer func() {
|
|
||||||
// Restore to info level after benchmark
|
|
||||||
manager.SetLogLevel(log.InfoLevel)
|
|
||||||
}()
|
|
||||||
|
|
||||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
|
||||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
|
||||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
|
||||||
|
|
||||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
|
||||||
require.NoError(b, err)
|
|
||||||
|
|
||||||
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.ReportAllocs()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// Create fresh packet each time to isolate allocation testing
|
|
||||||
testPacket := make([]byte, len(packet))
|
|
||||||
copy(testPacket, packet)
|
|
||||||
|
|
||||||
// Parse the packet fresh each time to get a clean decoder
|
|
||||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv4,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser.IgnoreUnsupported = true
|
|
||||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
|
||||||
assert.NoError(b, err)
|
|
||||||
|
|
||||||
manager.translateOutboundDNAT(testPacket, d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
|
|
||||||
func BenchmarkDirectIPExtraction(b *testing.B) {
|
|
||||||
// Create a test packet
|
|
||||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
|
||||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
|
||||||
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
|
|
||||||
|
|
||||||
b.Run("direct_byte_access", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// Direct extraction from packet bytes
|
|
||||||
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("decoder_extraction", func(b *testing.B) {
|
|
||||||
// Create decoder once for comparison
|
|
||||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv4,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser.IgnoreUnsupported = true
|
|
||||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
|
||||||
assert.NoError(b, err)
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// Extract using decoder (traditional method)
|
|
||||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
|
||||||
_ = dst
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
|
|
||||||
func BenchmarkChecksumOptimizations(b *testing.B) {
|
|
||||||
// Create test IPv4 header (20 bytes)
|
|
||||||
header := make([]byte, 20)
|
|
||||||
for i := range header {
|
|
||||||
header[i] = byte(i)
|
|
||||||
}
|
|
||||||
// Clear checksum field
|
|
||||||
header[10] = 0
|
|
||||||
header[11] = 0
|
|
||||||
|
|
||||||
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = ipv4Checksum(header)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test incremental checksum updates
|
|
||||||
oldIP := []byte{192, 168, 1, 100}
|
|
||||||
newIP := []byte{10, 0, 0, 100}
|
|
||||||
oldChecksum := uint16(0x1234)
|
|
||||||
|
|
||||||
b.Run("optimized_incremental_update", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,145 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
|
||||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
|
||||||
manager, err := Create(&IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}, false, flowLogger)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
|
||||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
|
||||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
|
||||||
|
|
||||||
// Add DNAT mapping
|
|
||||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
protocol layers.IPProtocol
|
|
||||||
srcPort uint16
|
|
||||||
dstPort uint16
|
|
||||||
}{
|
|
||||||
{"TCP", layers.IPProtocolTCP, 12345, 80},
|
|
||||||
{"UDP", layers.IPProtocolUDP, 12345, 53},
|
|
||||||
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
// Test outbound DNAT translation
|
|
||||||
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
|
|
||||||
originalOutbound := make([]byte, len(outboundPacket))
|
|
||||||
copy(originalOutbound, outboundPacket)
|
|
||||||
|
|
||||||
// Process outbound packet (should translate destination)
|
|
||||||
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
|
|
||||||
require.True(t, translated, "Outbound packet should be translated")
|
|
||||||
|
|
||||||
// Verify destination IP was changed
|
|
||||||
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
|
|
||||||
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
|
|
||||||
|
|
||||||
// Test inbound reverse DNAT translation
|
|
||||||
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
|
|
||||||
originalInbound := make([]byte, len(inboundPacket))
|
|
||||||
copy(originalInbound, inboundPacket)
|
|
||||||
|
|
||||||
// Process inbound packet (should reverse translate source)
|
|
||||||
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
|
|
||||||
require.True(t, reversed, "Inbound packet should be reverse translated")
|
|
||||||
|
|
||||||
// Verify source IP was changed back to original
|
|
||||||
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
|
|
||||||
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
|
|
||||||
|
|
||||||
// Test that checksums are recalculated correctly
|
|
||||||
if tc.protocol != layers.IPProtocolICMPv4 {
|
|
||||||
// For TCP/UDP, verify the transport checksum was updated
|
|
||||||
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
|
|
||||||
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// parsePacket helper to create a decoder for testing
|
|
||||||
func parsePacket(t testing.TB, packetData []byte) *decoder {
|
|
||||||
t.Helper()
|
|
||||||
d := &decoder{
|
|
||||||
decoded: []gopacket.LayerType{},
|
|
||||||
}
|
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv4,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser.IgnoreUnsupported = true
|
|
||||||
|
|
||||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
|
||||||
func TestDNATMappingManagement(t *testing.T) {
|
|
||||||
manager, err := Create(&IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}, false, flowLogger)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
|
||||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
|
||||||
|
|
||||||
// Test adding mapping
|
|
||||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify mapping exists
|
|
||||||
result, exists := manager.getDNATTranslation(originalIP)
|
|
||||||
require.True(t, exists)
|
|
||||||
require.Equal(t, translatedIP, result)
|
|
||||||
|
|
||||||
// Test reverse lookup
|
|
||||||
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
|
|
||||||
require.True(t, exists)
|
|
||||||
require.Equal(t, originalIP, reverseResult)
|
|
||||||
|
|
||||||
// Test removing mapping
|
|
||||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify mapping no longer exists
|
|
||||||
_, exists = manager.getDNATTranslation(originalIP)
|
|
||||||
require.False(t, exists)
|
|
||||||
|
|
||||||
_, exists = manager.findReverseDNATMapping(translatedIP)
|
|
||||||
require.False(t, exists)
|
|
||||||
|
|
||||||
// Test error cases
|
|
||||||
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
|
|
||||||
require.Error(t, err, "Should reject invalid original IP")
|
|
||||||
|
|
||||||
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
|
|
||||||
require.Error(t, err, "Should reject invalid translated IP")
|
|
||||||
|
|
||||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
|
||||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
|
||||||
}
|
|
||||||
@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.filterOutbound(packetData, 0)
|
dropped := m.processOutgoingHooks(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -38,8 +38,11 @@ func TestTracePacket(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("100.10.0.100"),
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,12 +39,8 @@ const (
|
|||||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||||
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
|
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||||
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
|
|
||||||
|
|
||||||
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
|
|
||||||
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
|
|
||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -75,6 +71,7 @@ type Manager struct {
|
|||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -104,12 +101,6 @@ type Manager struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
|
||||||
blockRule firewall.Rule
|
blockRule firewall.Rule
|
||||||
|
|
||||||
// Internal 1:1 DNAT
|
|
||||||
dnatEnabled atomic.Bool
|
|
||||||
dnatMappings map[netip.Addr]netip.Addr
|
|
||||||
dnatMutex sync.RWMutex
|
|
||||||
dnatBiMap *biDNATMap
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -157,11 +148,6 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
}
|
}
|
||||||
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
|
|
||||||
enableLocalForwarding, err = strconv.ParseBool(val)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
@@ -195,7 +181,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
@@ -284,7 +269,7 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
case !m.netstack && m.nativeFirewall != nil:
|
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
@@ -341,10 +326,6 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsStateful() bool {
|
|
||||||
return m.stateful
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
@@ -526,6 +507,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
@@ -572,14 +569,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterOutBound filters outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
|
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||||
return m.filterOutbound(packetData, size)
|
return m.processOutgoingHooks(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterInbound filters incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
|
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||||
return m.filterInbound(packetData, size)
|
return m.dropFilter(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs updates the list of local IPs
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
@@ -587,7 +584,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -609,8 +606,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
m.trackOutbound(d, srcIP, dstIP, size)
|
if m.stateful {
|
||||||
m.translateOutboundDNAT(packetData, d)
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -662,7 +660,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -675,7 +673,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -714,9 +712,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterInbound implements filtering logic for incoming packets.
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -738,15 +736,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
// Re-decode after translation to get original addresses
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
|
||||||
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
srcIP, dstIP = m.extractIPs(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -786,10 +777,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
// if running in netstack mode we need to pass this to the forwarder
|
||||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
if m.netstack && m.localForwarding {
|
||||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
return m.handleNetstackLocalTraffic(packetData)
|
||||||
return m.handleForwardedLocalTraffic(packetData)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// track inbound packets to get the correct direction and session id for flows
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
@@ -799,7 +789,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||||
|
|
||||||
fwd := m.forwarder.Load()
|
fwd := m.forwarder.Load()
|
||||||
if fwd == nil {
|
if fwd == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
@@ -1097,6 +1088,11 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
|
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||||
|
m.wgNetwork = network
|
||||||
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
@@ -174,6 +174,11 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
// Apply scenario-specific setup
|
// Apply scenario-specific setup
|
||||||
sc.setupFunc(manager)
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
@@ -188,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.filterOutbound(outbound, 0)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.filterInbound(inbound, 0)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -214,13 +219,18 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
// Pre-populate connection table
|
// Pre-populate connection table
|
||||||
srcIPs := generateRandomIPs(count)
|
srcIPs := generateRandomIPs(count)
|
||||||
dstIPs := generateRandomIPs(count)
|
dstIPs := generateRandomIPs(count)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.filterOutbound(outbound, 0)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@@ -228,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.filterOutbound(testOut, 0)
|
manager.processOutgoingHooks(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.filterInbound(testIn, 0)
|
manager.dropFilter(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -257,18 +267,23 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
srcIP := generateRandomIPs(1)[0]
|
srcIP := generateRandomIPs(1)[0]
|
||||||
dstIP := generateRandomIPs(1)[0]
|
dstIP := generateRandomIPs(1)[0]
|
||||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.filterOutbound(outbound, 0)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.filterInbound(inbound, 0)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -289,6 +304,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -302,6 +321,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -316,6 +339,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -329,6 +356,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -342,6 +373,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -355,6 +390,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -369,6 +408,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "post_handshake",
|
state: "post_handshake",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -383,6 +426,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -396,6 +443,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -426,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.filterOutbound(outbound, 0)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.filterOutbound(syn, 0)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.filterInbound(synack, 0)
|
manager.dropFilter(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.filterOutbound(ack, 0)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.filterInbound(inbound, 0)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -542,6 +593,11 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -568,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.filterOutbound(syn, 0)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.filterInbound(synack, 0)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.filterOutbound(ack, 0)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@@ -599,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.filterOutbound(outPackets[connIdx], 0)
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.filterInbound(inPackets[connIdx], 0)
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -625,6 +681,11 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -700,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.filterOutbound(p.syn, 0)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.filterInbound(p.synAck, 0)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.filterOutbound(p.ack, 0)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.filterOutbound(p.request, 0)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.filterInbound(p.response, 0)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.filterOutbound(p.finClient, 0)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.filterInbound(p.ackServer, 0)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.filterInbound(p.finServer, 0)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.filterOutbound(p.ackClient, 0)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -736,6 +797,11 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
@@ -760,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.filterOutbound(syn, 0)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.filterInbound(synack, 0)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.filterOutbound(ack, 0)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@@ -790,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.filterOutbound(outPackets[connIdx], 0)
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
manager.filterInbound(inPackets[connIdx], 0)
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -816,6 +882,11 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
@@ -879,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.filterOutbound(p.syn, 0)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.filterInbound(p.synAck, 0)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.filterOutbound(p.ack, 0)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
manager.filterOutbound(p.request, 0)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.filterInbound(p.response, 0)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
manager.filterOutbound(p.finClient, 0)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.filterInbound(p.ackServer, 0)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.filterInbound(p.finServer, 0)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.filterOutbound(p.ackClient, 0)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -961,8 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
dst := fw.Network{Prefix: r.dest}
|
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -19,8 +19,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
localIP := netip.MustParseAddr("100.10.0.100")
|
localIP := net.ParseIP("100.10.0.100")
|
||||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
wgNet := &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
@@ -39,6 +43,8 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = wgNet
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
err = manager.UpdateLocalIPs()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -462,7 +468,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.FilterInbound(packet, 0)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -509,7 +515,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.FilterInbound(packet, 0)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -575,13 +581,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
dev := mocks.NewMockDevice(ctrl)
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
wgNet := netip.MustParsePrefix(network)
|
localIP, wgNet, err := net.ParseCIDR(network)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: wgNet.Addr(),
|
IP: localIP,
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -1233,7 +1240,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
||||||
@@ -1433,8 +1440,11 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("100.10.0.100"),
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -271,8 +271,11 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("100.10.0.100"),
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -282,6 +285,10 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
@@ -321,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.filterInbound(buf.Bytes(), 0) {
|
if m.dropFilter(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -389,6 +396,10 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -447,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.filterOutbound(buf.Bytes(), 0)
|
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@@ -457,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.filterOutbound(buf.Bytes(), 0)
|
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -498,6 +509,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
@@ -553,7 +569,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
|
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
@@ -620,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
|
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -669,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@@ -691,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.filterInbound(testBuf.Bytes(), 0)
|
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
package bind
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/monotime"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
saveFrequency = int64(5 * time.Second)
|
|
||||||
)
|
|
||||||
|
|
||||||
type PeerRecord struct {
|
|
||||||
Address netip.AddrPort
|
|
||||||
LastActivity atomic.Int64 // UnixNano timestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
type ActivityRecorder struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
peers map[string]*PeerRecord // publicKey to PeerRecord map
|
|
||||||
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewActivityRecorder() *ActivityRecorder {
|
|
||||||
return &ActivityRecorder{
|
|
||||||
peers: make(map[string]*PeerRecord),
|
|
||||||
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLastActivities returns a snapshot of peer last activity
|
|
||||||
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
activities := make(map[string]monotime.Time, len(r.peers))
|
|
||||||
for key, record := range r.peers {
|
|
||||||
monoTime := record.LastActivity.Load()
|
|
||||||
activities[key] = monotime.Time(monoTime)
|
|
||||||
}
|
|
||||||
return activities
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpsertAddress adds or updates the address for a publicKey
|
|
||||||
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
var record *PeerRecord
|
|
||||||
record, exists := r.peers[publicKey]
|
|
||||||
if exists {
|
|
||||||
delete(r.addrToPeer, record.Address)
|
|
||||||
record.Address = address
|
|
||||||
} else {
|
|
||||||
record = &PeerRecord{
|
|
||||||
Address: address,
|
|
||||||
}
|
|
||||||
record.LastActivity.Store(int64(monotime.Now()))
|
|
||||||
r.peers[publicKey] = record
|
|
||||||
}
|
|
||||||
|
|
||||||
r.addrToPeer[address] = record
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ActivityRecorder) Remove(publicKey string) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if record, exists := r.peers[publicKey]; exists {
|
|
||||||
delete(r.addrToPeer, record.Address)
|
|
||||||
delete(r.peers, publicKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// record updates LastActivity for the given address using atomic store
|
|
||||||
func (r *ActivityRecorder) record(address netip.AddrPort) {
|
|
||||||
r.mu.RLock()
|
|
||||||
record, ok := r.addrToPeer[address]
|
|
||||||
r.mu.RUnlock()
|
|
||||||
if !ok {
|
|
||||||
log.Warnf("could not find record for address %s", address)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
now := int64(monotime.Now())
|
|
||||||
last := record.LastActivity.Load()
|
|
||||||
if now-last < saveFrequency {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = record.LastActivity.CompareAndSwap(last, now)
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
package bind
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/monotime"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
|
||||||
peer := "peer1"
|
|
||||||
ar := NewActivityRecorder()
|
|
||||||
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
|
|
||||||
activities := ar.GetLastActivities()
|
|
||||||
|
|
||||||
p, ok := activities[peer]
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
if monotime.Since(p) > 5*time.Second {
|
|
||||||
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
package bind
|
|
||||||
|
|
||||||
import (
|
|
||||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
|
||||||
func init() {
|
|
||||||
listener := nbnet.NewListener()
|
|
||||||
if listener.ListenConfig.Control != nil {
|
|
||||||
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ControlFns is not thread safe and should only be modified during init.
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package bind
|
package bind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -16,7 +15,6 @@ import (
|
|||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
type RecvMessage struct {
|
||||||
@@ -53,24 +51,22 @@ type ICEBind struct {
|
|||||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *UniversalUDPMuxDefault
|
udpMux *UniversalUDPMuxDefault
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
activityRecorder *ActivityRecorder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
RecvChan: make(chan RecvMessage, 1),
|
RecvChan: make(chan RecvMessage, 1),
|
||||||
transportNet: transportNet,
|
transportNet: transportNet,
|
||||||
filterFn: filterFn,
|
filterFn: filterFn,
|
||||||
endpoints: make(map[netip.Addr]net.Conn),
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
address: address,
|
address: address,
|
||||||
activityRecorder: NewActivityRecorder(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := receiverCreator{
|
rc := receiverCreator{
|
||||||
@@ -104,10 +100,6 @@ func (s *ICEBind) Close() error {
|
|||||||
return s.StdNetBind.Close()
|
return s.StdNetBind.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
|
||||||
return s.activityRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
@@ -154,7 +146,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
|
|
||||||
s.udpMux = NewUniversalUDPMuxDefault(
|
s.udpMux = NewUniversalUDPMuxDefault(
|
||||||
UniversalUDPMuxParams{
|
UniversalUDPMuxParams{
|
||||||
UDPConn: nbnet.WrapUDPConn(conn),
|
UDPConn: conn,
|
||||||
Net: s.transportNet,
|
Net: s.transportNet,
|
||||||
FilterFn: s.filterFn,
|
FilterFn: s.filterFn,
|
||||||
WGAddress: s.address,
|
WGAddress: s.address,
|
||||||
@@ -207,11 +199,6 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
|
||||||
if isTransportPkg(msg.Buffers, msg.N) {
|
|
||||||
s.activityRecorder.record(addrPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
eps[i] = ep
|
eps[i] = ep
|
||||||
@@ -270,13 +257,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
|||||||
copy(buffs[0], msg.Buffer)
|
copy(buffs[0], msg.Buffer)
|
||||||
sizes[0] = len(msg.Buffer)
|
sizes[0] = len(msg.Buffer)
|
||||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||||
|
|
||||||
if isTransportPkg(buffs, sizes[0]) {
|
|
||||||
if ep, ok := eps[0].(*Endpoint); ok {
|
|
||||||
c.activityRecorder.record(ep.AddrPort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -292,19 +272,3 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
|||||||
}
|
}
|
||||||
msgsPool.Put(msgs)
|
msgsPool.Put(msgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isTransportPkg(buffers [][]byte, n int) bool {
|
|
||||||
// The first buffer should contain at least 4 bytes for type
|
|
||||||
if len(buffers[0]) < 4 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// WireGuard packet type is a little-endian uint32 at start
|
|
||||||
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
|
|
||||||
|
|
||||||
// Check if packetType matches known WireGuard message types
|
|
||||||
if packetType == 4 && n > 32 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -296,20 +296,14 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var allAddresses []string
|
m.addressMapMu.Lock()
|
||||||
|
defer m.addressMapMu.Unlock()
|
||||||
|
|
||||||
for _, c := range removedConns {
|
for _, c := range removedConns {
|
||||||
addresses := c.getAddresses()
|
addresses := c.getAddresses()
|
||||||
allAddresses = append(allAddresses, addresses...)
|
for _, addr := range addresses {
|
||||||
}
|
delete(m.addressMap, addr)
|
||||||
|
}
|
||||||
m.addressMapMu.Lock()
|
|
||||||
for _, addr := range allAddresses {
|
|
||||||
delete(m.addressMap, addr)
|
|
||||||
}
|
|
||||||
m.addressMapMu.Unlock()
|
|
||||||
|
|
||||||
for _, addr := range allAddresses {
|
|
||||||
m.notifyAddressRemoval(addr)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,13 +351,14 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.addressMapMu.Lock()
|
m.addressMapMu.Lock()
|
||||||
|
defer m.addressMapMu.Unlock()
|
||||||
|
|
||||||
existing, ok := m.addressMap[addr]
|
existing, ok := m.addressMap[addr]
|
||||||
if !ok {
|
if !ok {
|
||||||
existing = []*udpMuxedConn{}
|
existing = []*udpMuxedConn{}
|
||||||
}
|
}
|
||||||
existing = append(existing, conn)
|
existing = append(existing, conn)
|
||||||
m.addressMap[addr] = existing
|
m.addressMap[addr] = existing
|
||||||
m.addressMapMu.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||||
}
|
}
|
||||||
@@ -391,12 +386,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
|
|||||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||||
// We will then forward STUN packets to each of these connections.
|
// We will then forward STUN packets to each of these connections.
|
||||||
m.addressMapMu.RLock()
|
m.addressMapMu.Lock()
|
||||||
var destinationConnList []*udpMuxedConn
|
var destinationConnList []*udpMuxedConn
|
||||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||||
destinationConnList = append(destinationConnList, storedConns...)
|
destinationConnList = append(destinationConnList, storedConns...)
|
||||||
}
|
}
|
||||||
m.addressMapMu.RUnlock()
|
m.addressMapMu.Unlock()
|
||||||
|
|
||||||
var isIPv6 bool
|
var isIPv6 bool
|
||||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
//go:build !ios
|
|
||||||
|
|
||||||
package bind
|
|
||||||
|
|
||||||
import (
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
|
||||||
wrapped, ok := m.params.UDPConn.(*UDPConn)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
nbnetConn.RemoveAddress(addr)
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build ios
|
|
||||||
|
|
||||||
package bind
|
|
||||||
|
|
||||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
|
||||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
|
||||||
}
|
|
||||||
@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
|
|
||||||
// wrap UDP connection, process server reflexive messages
|
// wrap UDP connection, process server reflexive messages
|
||||||
// before they are passed to the UDPMux connection handler (connWorker)
|
// before they are passed to the UDPMux connection handler (connWorker)
|
||||||
m.params.UDPConn = &UDPConn{
|
m.params.UDPConn = &udpConn{
|
||||||
PacketConn: params.UDPConn,
|
PacketConn: params.UDPConn,
|
||||||
mux: m,
|
mux: m,
|
||||||
logger: params.Logger,
|
logger: params.Logger,
|
||||||
@@ -70,6 +70,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
address: params.WGAddress,
|
address: params.WGAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// embed UDPMux
|
||||||
udpMuxParams := UDPMuxParams{
|
udpMuxParams := UDPMuxParams{
|
||||||
Logger: params.Logger,
|
Logger: params.Logger,
|
||||||
UDPConn: m.params.UDPConn,
|
UDPConn: m.params.UDPConn,
|
||||||
@@ -113,8 +114,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||||
type UDPConn struct {
|
type udpConn struct {
|
||||||
net.PacketConn
|
net.PacketConn
|
||||||
mux *UniversalUDPMuxDefault
|
mux *UniversalUDPMuxDefault
|
||||||
logger logging.LeveledLogger
|
logger logging.LeveledLogger
|
||||||
@@ -124,12 +125,7 @@ type UDPConn struct {
|
|||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPacketConn returns the underlying PacketConn
|
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
func (u *UDPConn) GetPacketConn() net.PacketConn {
|
|
||||||
return u.PacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
|
||||||
if u.filterFn == nil {
|
if u.filterFn == nil {
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
@@ -141,21 +137,21 @@ func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
|||||||
return u.handleUncachedAddress(b, addr)
|
return u.handleUncachedAddress(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||||
if isRouted {
|
if isRouted {
|
||||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||||
}
|
}
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||||
if err := u.performFilterCheck(addr); err != nil {
|
if err := u.performFilterCheck(addr); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||||
host, err := getHostFromAddr(addr)
|
host, err := getHostFromAddr(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||||
@@ -168,7 +164,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a) {
|
if u.address.Network.Contains(a.AsSlice()) {
|
||||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
package configurer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
|
||||||
ipNets := make([]net.IPNet, len(prefixes))
|
|
||||||
for i, prefix := range prefixes {
|
|
||||||
ipNets[i] = net.IPNet{
|
|
||||||
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
|
|
||||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ipNets
|
|
||||||
}
|
|
||||||
@@ -5,18 +5,13 @@ package configurer
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/monotime"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var zeroKey wgtypes.Key
|
|
||||||
|
|
||||||
type KernelConfigurer struct {
|
type KernelConfigurer struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
}
|
}
|
||||||
@@ -48,7 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -57,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: prefixesToIPNets(allowedIps),
|
AllowedIPs: allowedIps,
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
@@ -94,10 +89,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||||
ipNet := net.IPNet{
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
IP: allowedIP.Addr().AsSlice(),
|
if err != nil {
|
||||||
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -108,7 +103,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
UpdateOnly: true,
|
UpdateOnly: true,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
AllowedIPs: []net.IPNet{ipNet},
|
AllowedIPs: []net.IPNet{*ipNet},
|
||||||
}
|
}
|
||||||
|
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
@@ -121,10 +116,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||||
ipNet := net.IPNet{
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
IP: allowedIP.Addr().AsSlice(),
|
if err != nil {
|
||||||
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
return fmt.Errorf("parse allowed IP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -192,11 +187,7 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer wg.Close()
|
||||||
if err := wg.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close wgctrl client: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// validate if device with name exists
|
// validate if device with name exists
|
||||||
_, err = wg.Device(c.deviceName)
|
_, err = wg.Device(c.deviceName)
|
||||||
@@ -210,75 +201,14 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
func (c *KernelConfigurer) Close() {
|
func (c *KernelConfigurer) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||||
wg, err := wgctrl.New()
|
peer, err := c.getPeer(c.deviceName, peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("wgctl: %w", err)
|
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
return WGStats{
|
||||||
err = wg.Close()
|
LastHandshake: peer.LastHandshakeTime,
|
||||||
if err != nil {
|
TxBytes: peer.TransmitBytes,
|
||||||
log.Errorf("Got error while closing wgctl: %v", err)
|
RxBytes: peer.ReceiveBytes,
|
||||||
}
|
}, nil
|
||||||
}()
|
|
||||||
|
|
||||||
wgDevice, err := wg.Device(c.deviceName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
|
||||||
}
|
|
||||||
fullStats := &Stats{
|
|
||||||
DeviceName: wgDevice.Name,
|
|
||||||
PublicKey: wgDevice.PublicKey.String(),
|
|
||||||
ListenPort: wgDevice.ListenPort,
|
|
||||||
FWMark: wgDevice.FirewallMark,
|
|
||||||
Peers: []Peer{},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range wgDevice.Peers {
|
|
||||||
peer := Peer{
|
|
||||||
PublicKey: p.PublicKey.String(),
|
|
||||||
AllowedIPs: p.AllowedIPs,
|
|
||||||
TxBytes: p.TransmitBytes,
|
|
||||||
RxBytes: p.ReceiveBytes,
|
|
||||||
LastHandshake: p.LastHandshakeTime,
|
|
||||||
PresharedKey: p.PresharedKey != zeroKey,
|
|
||||||
}
|
|
||||||
if p.Endpoint != nil {
|
|
||||||
peer.Endpoint = *p.Endpoint
|
|
||||||
}
|
|
||||||
fullStats.Peers = append(fullStats.Peers, peer)
|
|
||||||
}
|
|
||||||
return fullStats, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
|
||||||
stats := make(map[string]WGStats)
|
|
||||||
wg, err := wgctrl.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("wgctl: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = wg.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Got error while closing wgctl: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wgDevice, err := wg.Device(c.deviceName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range wgDevice.Peers {
|
|
||||||
stats[peer.PublicKey.String()] = WGStats{
|
|
||||||
LastHandshake: peer.LastHandshakeTime,
|
|
||||||
TxBytes: peer.TransmitBytes,
|
|
||||||
RxBytes: peer.ReceiveBytes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return stats, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,4 +3,4 @@
|
|||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
// WgInterfaceDefault is a default interface name of Netbird
|
// WgInterfaceDefault is a default interface name of Netbird
|
||||||
const WgInterfaceDefault = "nb0"
|
const WgInterfaceDefault = "wt0"
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -16,40 +14,22 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/monotime"
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
privateKey = "private_key"
|
|
||||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
|
||||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
|
||||||
ipcKeyTxBytes = "tx_bytes"
|
|
||||||
ipcKeyRxBytes = "rx_bytes"
|
|
||||||
allowedIP = "allowed_ip"
|
|
||||||
endpoint = "endpoint"
|
|
||||||
fwmark = "fwmark"
|
|
||||||
listenPort = "listen_port"
|
|
||||||
publicKey = "public_key"
|
|
||||||
presharedKey = "preshared_key"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||||
|
|
||||||
type WGUSPConfigurer struct {
|
type WGUSPConfigurer struct {
|
||||||
device *device.Device
|
device *device.Device
|
||||||
deviceName string
|
deviceName string
|
||||||
activityRecorder *bind.ActivityRecorder
|
|
||||||
|
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
||||||
wgCfg := &WGUSPConfigurer{
|
wgCfg := &WGUSPConfigurer{
|
||||||
device: device,
|
device: device,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
activityRecorder: activityRecorder,
|
|
||||||
}
|
}
|
||||||
wgCfg.startUAPI()
|
wgCfg.startUAPI()
|
||||||
return wgCfg
|
return wgCfg
|
||||||
@@ -72,7 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -81,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: prefixesToIPNets(allowedIps),
|
AllowedIPs: allowedIps,
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
@@ -91,19 +71,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
Peers: []wgtypes.PeerConfig{peer},
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
}
|
}
|
||||||
|
|
||||||
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
return ipcErr
|
|
||||||
}
|
|
||||||
|
|
||||||
if endpoint != nil {
|
|
||||||
addr, err := netip.ParseAddr(endpoint.IP.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
|
||||||
}
|
|
||||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
|
||||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||||
@@ -120,16 +88,13 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
|||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
Peers: []wgtypes.PeerConfig{peer},
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
}
|
}
|
||||||
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
|
|
||||||
c.activityRecorder.Remove(peerKey)
|
|
||||||
return ipcErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||||
ipNet := net.IPNet{
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
IP: allowedIP.Addr().AsSlice(),
|
if err != nil {
|
||||||
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -140,7 +105,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
UpdateOnly: true,
|
UpdateOnly: true,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
AllowedIPs: []net.IPNet{ipNet},
|
AllowedIPs: []net.IPNet{*ipNet},
|
||||||
}
|
}
|
||||||
|
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
@@ -150,7 +115,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||||
ipc, err := c.device.IpcGet()
|
ipc, err := c.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -173,8 +138,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
|
|||||||
|
|
||||||
foundPeer := false
|
foundPeer := false
|
||||||
removedAllowedIP := false
|
removedAllowedIP := false
|
||||||
ip := allowedIP.String()
|
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
@@ -197,8 +160,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
|
|||||||
|
|
||||||
// Append the line to the output string
|
// Append the line to the output string
|
||||||
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
||||||
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
|
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIPStr)
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -215,19 +178,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
|
||||||
ipcStr, err := c.device.IpcGet()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return parseStatus(c.deviceName, ipcStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
|
|
||||||
return c.activityRecorder.GetLastActivities()
|
|
||||||
}
|
|
||||||
|
|
||||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||||
func (t *WGUSPConfigurer) startUAPI() {
|
func (t *WGUSPConfigurer) startUAPI() {
|
||||||
var err error
|
var err error
|
||||||
@@ -267,75 +217,91 @@ func (t *WGUSPConfigurer) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||||
ipc, err := t.device.IpcGet()
|
ipc, err := t.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("ipc get: %w", err)
|
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseTransfers(ipc)
|
stats, err := findPeerInfo(ipc, peerKey, []string{
|
||||||
|
"last_handshake_time_sec",
|
||||||
|
"last_handshake_time_nsec",
|
||||||
|
"tx_bytes",
|
||||||
|
"rx_bytes",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||||
|
}
|
||||||
|
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
||||||
|
}
|
||||||
|
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return WGStats{
|
||||||
|
LastHandshake: time.Unix(sec, nsec),
|
||||||
|
TxBytes: txBytes,
|
||||||
|
RxBytes: rxBytes,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
||||||
stats := make(map[string]WGStats)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
var (
|
if err != nil {
|
||||||
currentKey string
|
return nil, fmt.Errorf("parse key: %w", err)
|
||||||
currentStats WGStats
|
}
|
||||||
hasPeer bool
|
|
||||||
)
|
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
||||||
lines := strings.Split(ipc, "\n")
|
|
||||||
|
lines := strings.Split(ipcInput, "\n")
|
||||||
|
|
||||||
|
configFound := map[string]string{}
|
||||||
|
foundPeer := false
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
// If we're within the details of the found peer and encounter another public key,
|
// If we're within the details of the found peer and encounter another public key,
|
||||||
// this means we're starting another peer's details. So, stop.
|
// this means we're starting another peer's details. So, stop.
|
||||||
if strings.HasPrefix(line, "public_key=") {
|
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
||||||
peerID := strings.TrimPrefix(line, "public_key=")
|
break
|
||||||
h, err := hex.DecodeString(peerID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("decode peerID: %w", err)
|
|
||||||
}
|
|
||||||
currentKey = base64.StdEncoding.EncodeToString(h)
|
|
||||||
currentStats = WGStats{} // Reset stats for the new peer
|
|
||||||
hasPeer = true
|
|
||||||
stats[currentKey] = currentStats
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasPeer {
|
// Identify the peer with the specific public key
|
||||||
continue
|
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
||||||
|
foundPeer = true
|
||||||
}
|
}
|
||||||
|
|
||||||
key := strings.SplitN(line, "=", 2)
|
for _, key := range searchConfigKeys {
|
||||||
if len(key) != 2 {
|
if foundPeer && strings.HasPrefix(line, key+"=") {
|
||||||
continue
|
v := strings.SplitN(line, "=", 2)
|
||||||
}
|
configFound[v[0]] = v[1]
|
||||||
switch key[0] {
|
|
||||||
case ipcKeyLastHandshakeTimeSec:
|
|
||||||
hs, err := toLastHandshake(key[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
currentStats.LastHandshake = hs
|
|
||||||
stats[currentKey] = currentStats
|
|
||||||
case ipcKeyRxBytes:
|
|
||||||
rxBytes, err := toBytes(key[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
currentStats.RxBytes = rxBytes
|
|
||||||
stats[currentKey] = currentStats
|
|
||||||
case ipcKeyTxBytes:
|
|
||||||
TxBytes, err := toBytes(key[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
currentStats.TxBytes = TxBytes
|
|
||||||
stats[currentKey] = currentStats
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return stats, nil
|
// todo: use multierr
|
||||||
|
for _, key := range searchConfigKeys {
|
||||||
|
if _, ok := configFound[key]; !ok {
|
||||||
|
return configFound, fmt.Errorf("config key not found: %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundPeer {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return configFound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||||
@@ -389,154 +355,9 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func toLastHandshake(stringVar string) (time.Time, error) {
|
|
||||||
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
|
||||||
}
|
|
||||||
return time.Unix(sec, 0), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toBytes(s string) (int64, error) {
|
|
||||||
return strconv.ParseInt(s, 10, 64)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.ControlPlaneMark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
|
||||||
// Decode hex string to bytes
|
|
||||||
keyBytes, err := hex.DecodeString(hexKey)
|
|
||||||
if err != nil {
|
|
||||||
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
|
||||||
if len(keyBytes) != 32 {
|
|
||||||
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to wgtypes.Key
|
|
||||||
var key wgtypes.Key
|
|
||||||
copy(key[:], keyBytes)
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
|
||||||
stats := &Stats{DeviceName: deviceName}
|
|
||||||
var currentPeer *Peer
|
|
||||||
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
|
||||||
if line == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
parts := strings.SplitN(line, "=", 2)
|
|
||||||
if len(parts) != 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
key := parts[0]
|
|
||||||
val := parts[1]
|
|
||||||
|
|
||||||
switch key {
|
|
||||||
case privateKey:
|
|
||||||
key, err := hexToWireguardKey(val)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to parse private key: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
stats.PublicKey = key.PublicKey().String()
|
|
||||||
case publicKey:
|
|
||||||
// Save previous peer
|
|
||||||
if currentPeer != nil {
|
|
||||||
stats.Peers = append(stats.Peers, *currentPeer)
|
|
||||||
}
|
|
||||||
key, err := hexToWireguardKey(val)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to parse public key: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
currentPeer = &Peer{
|
|
||||||
PublicKey: key.String(),
|
|
||||||
}
|
|
||||||
case listenPort:
|
|
||||||
if port, err := strconv.Atoi(val); err == nil {
|
|
||||||
stats.ListenPort = port
|
|
||||||
}
|
|
||||||
case fwmark:
|
|
||||||
if fwmark, err := strconv.Atoi(val); err == nil {
|
|
||||||
stats.FWMark = fwmark
|
|
||||||
}
|
|
||||||
case endpoint:
|
|
||||||
if currentPeer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to parse endpoint: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to parse endpoint port: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
currentPeer.Endpoint = net.UDPAddr{
|
|
||||||
IP: net.ParseIP(host),
|
|
||||||
Port: port,
|
|
||||||
}
|
|
||||||
case allowedIP:
|
|
||||||
if currentPeer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
_, ipnet, err := net.ParseCIDR(val)
|
|
||||||
if err == nil {
|
|
||||||
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
|
||||||
}
|
|
||||||
case ipcKeyTxBytes:
|
|
||||||
if currentPeer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rxBytes, err := toBytes(val)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
currentPeer.TxBytes = rxBytes
|
|
||||||
case ipcKeyRxBytes:
|
|
||||||
if currentPeer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rxBytes, err := toBytes(val)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
currentPeer.RxBytes = rxBytes
|
|
||||||
|
|
||||||
case ipcKeyLastHandshakeTimeSec:
|
|
||||||
if currentPeer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ts, err := toLastHandshake(val)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
currentPeer.LastHandshake = ts
|
|
||||||
case presharedKey:
|
|
||||||
if currentPeer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if val != "" {
|
|
||||||
currentPeer.PresharedKey = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if currentPeer != nil {
|
|
||||||
stats.Peers = append(stats.Peers, *currentPeer)
|
|
||||||
}
|
|
||||||
return stats, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package configurer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@@ -32,35 +34,58 @@ errno=0
|
|||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
func Test_parseTransfers(t *testing.T) {
|
func Test_findPeerInfo(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
peerKey string
|
peerKey string
|
||||||
want WGStats
|
searchKeys []string
|
||||||
|
want map[string]string
|
||||||
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single",
|
name: "single",
|
||||||
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||||
want: WGStats{
|
searchKeys: []string{"tx_bytes"},
|
||||||
TxBytes: 0,
|
want: map[string]string{
|
||||||
RxBytes: 0,
|
"tx_bytes": "38333",
|
||||||
},
|
},
|
||||||
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple",
|
name: "multiple",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||||
want: WGStats{
|
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||||
TxBytes: 38333,
|
want: map[string]string{
|
||||||
RxBytes: 2224,
|
"tx_bytes": "38333",
|
||||||
|
"rx_bytes": "2224",
|
||||||
},
|
},
|
||||||
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "lastpeer",
|
name: "lastpeer",
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||||
want: WGStats{
|
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||||
TxBytes: 1212111,
|
want: map[string]string{
|
||||||
RxBytes: 1929999999,
|
"tx_bytes": "1212111",
|
||||||
|
"rx_bytes": "1929999999",
|
||||||
},
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer not found",
|
||||||
|
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
||||||
|
searchKeys: nil,
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "key not found",
|
||||||
|
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||||
|
searchKeys: []string{"tx_bytes", "unknown_key"},
|
||||||
|
want: map[string]string{
|
||||||
|
"tx_bytes": "1212111",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -71,19 +96,9 @@ func Test_parseTransfers(t *testing.T) {
|
|||||||
key, err := wgtypes.NewKey(res)
|
key, err := wgtypes.NewKey(res)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
stats, err := parseTransfers(ipcFixture)
|
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
||||||
if err != nil {
|
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
||||||
require.NoError(t, err)
|
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
stat, ok := stats[key.String()]
|
|
||||||
if !ok {
|
|
||||||
require.True(t, ok)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, tt.want, stat)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
package configurer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Peer struct {
|
|
||||||
PublicKey string
|
|
||||||
Endpoint net.UDPAddr
|
|
||||||
AllowedIPs []net.IPNet
|
|
||||||
TxBytes int64
|
|
||||||
RxBytes int64
|
|
||||||
LastHandshake time.Time
|
|
||||||
PresharedKey bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Stats struct {
|
|
||||||
DeviceName string
|
|
||||||
PublicKey string
|
|
||||||
ListenPort int
|
|
||||||
FWMark int
|
|
||||||
Peers []Peer
|
|
||||||
}
|
|
||||||
@@ -24,7 +24,6 @@ type WGTunDevice struct {
|
|||||||
mtu int
|
mtu int
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
disableDNS bool
|
|
||||||
|
|
||||||
name string
|
name string
|
||||||
device *device.Device
|
device *device.Device
|
||||||
@@ -33,7 +32,7 @@ type WGTunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -41,7 +40,6 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
|||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: iceBind,
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
disableDNS: disableDNS,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,13 +49,6 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
routesString := routesToString(routes)
|
routesString := routesToString(routes)
|
||||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||||
|
|
||||||
// Skip DNS configuration when DisableDNS is enabled
|
|
||||||
if t.disableDNS {
|
|
||||||
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
|
||||||
dns = ""
|
|
||||||
searchDomainsToString = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
@@ -79,7 +70,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -9,11 +10,11 @@ import (
|
|||||||
|
|
||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// FilterOutbound filter outgoing packets from host to external destinations
|
// DropOutgoing filter outgoing packets from host to external destinations
|
||||||
FilterOutbound(packetData []byte, size int) bool
|
DropOutgoing(packetData []byte, size int) bool
|
||||||
|
|
||||||
// FilterInbound filter incoming packets from external sources to host
|
// DropIncoming filter incoming packets from external sources to host
|
||||||
FilterInbound(packetData []byte, size int) bool
|
DropIncoming(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
@@ -23,6 +24,9 @@ type PacketFilter interface {
|
|||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
|
|
||||||
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
|
SetNetwork(*net.IPNet)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
@@ -54,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@@ -78,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -51,11 +51,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
log.Info("create nbnetstack tun interface")
|
log.Info("create nbnetstack tun interface")
|
||||||
|
|
||||||
// TODO: get from service listener runtime IP
|
// TODO: get from service listener runtime IP
|
||||||
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("last ip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("netstack using address: %s", t.address.IP)
|
log.Debugf("netstack using address: %s", t.address.IP)
|
||||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||||
@@ -72,7 +68,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
_ = tunIface.Close()
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -2,23 +2,19 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGConfigurer interface {
|
type WGConfigurer interface {
|
||||||
ConfigureInterface(privateKey string, port int) error
|
ConfigureInterface(privateKey string, port int) error
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
AddAllowedIP(peerKey string, allowedIP string) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||||
Close()
|
Close()
|
||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
FullStats() (*configurer.Stats, error)
|
|
||||||
LastActivities() map[string]monotime.Time
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,15 +64,7 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ip := address.IP.String()
|
ip := address.IP.String()
|
||||||
|
mask := "0x" + address.Network.Mask.String()
|
||||||
// Convert prefix length to hex netmask
|
|
||||||
prefixLen := address.Network.Bits()
|
|
||||||
if !address.IP.Is4() {
|
|
||||||
return fmt.Errorf("IPv6 not supported for interface assignment")
|
|
||||||
}
|
|
||||||
|
|
||||||
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
|
||||||
mask := fmt.Sprintf("0x%08x", maskBits)
|
|
||||||
|
|
||||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -30,11 +29,6 @@ const (
|
|||||||
WgInterfaceDefault = configurer.WgInterfaceDefault
|
WgInterfaceDefault = configurer.WgInterfaceDefault
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrIfaceNotFound is returned when the WireGuard interface is not found
|
|
||||||
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type wgProxyFactory interface {
|
type wgProxyFactory interface {
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
Free() error
|
Free() error
|
||||||
@@ -49,7 +43,6 @@ type WGIFaceOpts struct {
|
|||||||
MobileArgs *device.MobileIFaceArguments
|
MobileArgs *device.MobileIFaceArguments
|
||||||
TransportNet transport.Net
|
TransportNet transport.Net
|
||||||
FilterFn bind.FilterFn
|
FilterFn bind.FilterFn
|
||||||
DisableDNS bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WGIface represents an interface instance
|
// WGIface represents an interface instance
|
||||||
@@ -118,50 +111,38 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||||
// Endpoint is optional.
|
// Endpoint is optional
|
||||||
// If allowedIps is given it will be added to the existing ones.
|
|
||||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
if w.configurer == nil {
|
|
||||||
return ErrIfaceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
|
netIPNets := prefixesToIPNets(allowedIps)
|
||||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
|
||||||
|
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
if w.configurer == nil {
|
|
||||||
return ErrIfaceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
|
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
|
||||||
return w.configurer.RemovePeer(peerKey)
|
return w.configurer.RemovePeer(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
||||||
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
if w.configurer == nil {
|
|
||||||
return ErrIfaceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||||
return w.configurer.AddAllowedIP(peerKey, allowedIP)
|
return w.configurer.AddAllowedIP(peerKey, allowedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
||||||
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
if w.configurer == nil {
|
|
||||||
return ErrIfaceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||||
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
|
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
|
||||||
@@ -204,6 +185,7 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.filter = filter
|
w.filter = filter
|
||||||
|
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
||||||
|
|
||||||
w.tun.FilteredDevice().SetFilter(filter)
|
w.tun.FilteredDevice().SetFilter(filter)
|
||||||
return nil
|
return nil
|
||||||
@@ -230,32 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
|||||||
return w.tun.Device()
|
return w.tun.Device()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes
|
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||||
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
if w.configurer == nil {
|
return w.configurer.GetStats(peerKey)
|
||||||
return nil, ErrIfaceNotFound
|
|
||||||
}
|
|
||||||
return w.configurer.GetStats()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WGIface) LastActivities() map[string]monotime.Time {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if w.configurer == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.configurer.LastActivities()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
|
||||||
if w.configurer == nil {
|
|
||||||
return nil, ErrIfaceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.configurer.FullStats()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WGIface) waitUntilRemoved() error {
|
func (w *WGIface) waitUntilRemoved() error {
|
||||||
@@ -292,3 +251,14 @@ func (w *WGIface) GetNet() *netstack.Net {
|
|||||||
|
|
||||||
return w.tun.GetNet()
|
return w.tun.GetNet()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||||
|
ipNets := make([]net.IPNet, len(prefixes))
|
||||||
|
for i, prefix := range prefixes {
|
||||||
|
ipNets[i] = net.IPNet{
|
||||||
|
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
|
||||||
|
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ipNets
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
net "net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
@@ -48,32 +49,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterInbound mocks base method.
|
// DropIncoming mocks base method.
|
||||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
|
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterInbound indicates an expected call of FilterInbound.
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterOutbound mocks base method.
|
// DropOutgoing mocks base method.
|
||||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
@@ -89,3 +90,15 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNetwork mocks base method.
|
||||||
|
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetwork indicates an expected call of SetNetwork.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterInbound mocks base method.
|
// DropIncoming mocks base method.
|
||||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterInbound indicates an expected call of FilterInbound.
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterOutbound mocks base method.
|
// DropOutgoing mocks base method.
|
||||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
// SetNetwork mocks base method.
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -13,8 +15,8 @@ import (
|
|||||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||||
|
|
||||||
type NetStackTun struct { //nolint:revive
|
type NetStackTun struct { //nolint:revive
|
||||||
address netip.Addr
|
address net.IP
|
||||||
dnsAddress netip.Addr
|
dnsAddress net.IP
|
||||||
mtu int
|
mtu int
|
||||||
listenAddress string
|
listenAddress string
|
||||||
|
|
||||||
@@ -22,7 +24,7 @@ type NetStackTun struct { //nolint:revive
|
|||||||
tundev tun.Device
|
tundev tun.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
||||||
return &NetStackTun{
|
return &NetStackTun{
|
||||||
address: address,
|
address: address,
|
||||||
dnsAddress: dnsAddress,
|
dnsAddress: dnsAddress,
|
||||||
@@ -32,21 +34,28 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||||
|
addr, ok := netip.AddrFromSlice(t.address)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
||||||
|
}
|
||||||
|
|
||||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||||
[]netip.Addr{t.address},
|
[]netip.Addr{addr.Unmap()},
|
||||||
[]netip.Addr{t.dnsAddress},
|
[]netip.Addr{dnsAddr.Unmap()},
|
||||||
t.mtu)
|
t.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
t.tundev = nsTunDev
|
t.tundev = nsTunDev
|
||||||
|
|
||||||
var skipProxy bool
|
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||||
if val := os.Getenv(EnvSkipProxy); val != "" {
|
if err != nil {
|
||||||
skipProxy, err = strconv.ParseBool(val)
|
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if skipProxy {
|
if skipProxy {
|
||||||
return nsTunDev, tunNet, nil
|
return nsTunDev, tunNet, nil
|
||||||
|
|||||||
@@ -2,27 +2,28 @@ package wgaddr
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type Address struct {
|
type Address struct {
|
||||||
IP netip.Addr
|
IP net.IP
|
||||||
Network netip.Prefix
|
Network *net.IPNet
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
func ParseWGAddress(address string) (Address, error) {
|
func ParseWGAddress(address string) (Address, error) {
|
||||||
prefix, err := netip.ParsePrefix(address)
|
ip, network, err := net.ParseCIDR(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Address{}, err
|
return Address{}, err
|
||||||
}
|
}
|
||||||
return Address{
|
return Address{
|
||||||
IP: prefix.Addr().Unmap(),
|
IP: ip,
|
||||||
Network: prefix.Masked(),
|
Network: network,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (addr Address) String() string {
|
func (addr Address) String() string {
|
||||||
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
maskSize, _ := addr.Network.Mask.Size()
|
||||||
|
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxyBind struct {
|
type ProxyBind struct {
|
||||||
@@ -29,17 +28,6 @@ type ProxyBind struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
|
||||||
p := &ProxyBind{
|
|
||||||
Bind: bind,
|
|
||||||
closeListener: listener.NewCloseListener(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn adds a new connection to the bind.
|
// AddTurnConn adds a new connection to the bind.
|
||||||
@@ -66,10 +54,6 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
|
||||||
p.closeListener.SetCloseListener(disconnected)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ProxyBind) Work() {
|
func (p *ProxyBind) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
@@ -112,9 +96,6 @@ func (p *ProxyBind) close() error {
|
|||||||
if p.closed {
|
if p.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.closeListener.SetCloseListener(nil)
|
|
||||||
|
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
@@ -141,7 +122,6 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.closeListener.Notify()
|
|
||||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
@@ -28,15 +26,6 @@ type ProxyWrapper struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
|
||||||
return &ProxyWrapper{
|
|
||||||
WgeBPFProxy: WgeBPFProxy,
|
|
||||||
closeListener: listener.NewCloseListener(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
@@ -54,10 +43,6 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
|||||||
return p.wgEndpointAddr
|
return p.wgEndpointAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
|
||||||
p.closeListener.SetCloseListener(disconnected)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ProxyWrapper) Work() {
|
func (p *ProxyWrapper) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
@@ -92,8 +77,6 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
|
|
||||||
e.cancel()
|
e.cancel()
|
||||||
|
|
||||||
e.closeListener.SetCloseListener(nil)
|
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
@@ -134,7 +117,6 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return 0, ctx.Err()
|
return 0, ctx.Err()
|
||||||
}
|
}
|
||||||
p.closeListener.Notify()
|
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,8 +36,9 @@ func (w *KernelFactory) GetProxy() Proxy {
|
|||||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
return &ebpf.ProxyWrapper{
|
||||||
|
WgeBPFProxy: w.ebpfProxy,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
func (w *KernelFactory) Free() error {
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) GetProxy() Proxy {
|
func (w *USPFactory) GetProxy() Proxy {
|
||||||
return proxyBind.NewProxyBind(w.bind)
|
return &proxyBind.ProxyBind{
|
||||||
|
Bind: w.bind,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) Free() error {
|
func (w *USPFactory) Free() error {
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
package listener
|
|
||||||
|
|
||||||
type CloseListener struct {
|
|
||||||
listener func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCloseListener() *CloseListener {
|
|
||||||
return &CloseListener{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *CloseListener) SetCloseListener(listener func()) {
|
|
||||||
c.listener = listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *CloseListener) Notify() {
|
|
||||||
if c.listener != nil {
|
|
||||||
c.listener()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -12,5 +12,4 @@ type Proxy interface {
|
|||||||
Work() // Work start or resume the proxy
|
Work() // Work start or resume the proxy
|
||||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
SetDisconnectListener(disconnected func())
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,7 +98,9 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
proxyWrapper := &ebpf.ProxyWrapper{
|
||||||
|
WgeBPFProxy: ebpfProxy,
|
||||||
|
}
|
||||||
|
|
||||||
tests = append(tests, struct {
|
tests = append(tests, struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
cerrors "github.com/netbirdio/netbird/client/errors"
|
cerrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGUDPProxy proxies
|
// WGUDPProxy proxies
|
||||||
@@ -29,8 +28,6 @@ type WGUDPProxy struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
||||||
@@ -38,7 +35,6 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
|||||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||||
p := &WGUDPProxy{
|
p := &WGUDPProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
closeListener: listener.NewCloseListener(),
|
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
@@ -71,10 +67,6 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
|
|||||||
return endpointUdpAddr
|
return endpointUdpAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
|
|
||||||
p.closeListener.SetCloseListener(disconnected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Work starts the proxy or resumes it if it was paused
|
// Work starts the proxy or resumes it if it was paused
|
||||||
func (p *WGUDPProxy) Work() {
|
func (p *WGUDPProxy) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
@@ -119,8 +111,6 @@ func (p *WGUDPProxy) close() error {
|
|||||||
if p.closed {
|
if p.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.closeListener.SetCloseListener(nil)
|
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
@@ -151,7 +141,6 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.closeListener.Notify()
|
|
||||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,8 +24,6 @@
|
|||||||
|
|
||||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||||
|
|
||||||
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
|
|
||||||
|
|
||||||
Unicode True
|
Unicode True
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
@@ -51,10 +49,6 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
!include "MUI2.nsh"
|
|
||||||
!include LogicLib.nsh
|
|
||||||
!include "nsDialogs.nsh"
|
|
||||||
|
|
||||||
!define MUI_ICON "${ICON}"
|
!define MUI_ICON "${ICON}"
|
||||||
!define MUI_UNICON "${ICON}"
|
!define MUI_UNICON "${ICON}"
|
||||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||||
@@ -64,6 +58,9 @@ ShowInstDetails Show
|
|||||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
|
!include "MUI2.nsh"
|
||||||
|
!include LogicLib.nsh
|
||||||
|
|
||||||
!define MUI_ABORTWARNING
|
!define MUI_ABORTWARNING
|
||||||
!define MUI_UNABORTWARNING
|
!define MUI_UNABORTWARNING
|
||||||
|
|
||||||
@@ -73,16 +70,13 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_PAGE_DIRECTORY
|
!insertmacro MUI_PAGE_DIRECTORY
|
||||||
|
|
||||||
|
; Custom page for autostart checkbox
|
||||||
Page custom AutostartPage AutostartPageLeave
|
Page custom AutostartPage AutostartPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_INSTFILES
|
!insertmacro MUI_PAGE_INSTFILES
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_FINISH
|
!insertmacro MUI_PAGE_FINISH
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_WELCOME
|
|
||||||
|
|
||||||
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_CONFIRM
|
!insertmacro MUI_UNPAGE_CONFIRM
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_INSTFILES
|
!insertmacro MUI_UNPAGE_INSTFILES
|
||||||
@@ -95,10 +89,6 @@ UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
|||||||
Var AutostartCheckbox
|
Var AutostartCheckbox
|
||||||
Var AutostartEnabled
|
Var AutostartEnabled
|
||||||
|
|
||||||
; Variables for uninstall data deletion option
|
|
||||||
Var DeleteDataCheckbox
|
|
||||||
Var DeleteDataEnabled
|
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
; Function to create the autostart options page
|
; Function to create the autostart options page
|
||||||
@@ -114,8 +104,8 @@ Function AutostartPage
|
|||||||
|
|
||||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||||
Pop $AutostartCheckbox
|
Pop $AutostartCheckbox
|
||||||
${NSD_Check} $AutostartCheckbox
|
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||||
StrCpy $AutostartEnabled "1"
|
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||||
|
|
||||||
nsDialogs::Show
|
nsDialogs::Show
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
@@ -125,30 +115,6 @@ Function AutostartPageLeave
|
|||||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|
||||||
; Function to create the uninstall data deletion page
|
|
||||||
Function un.DeleteDataPage
|
|
||||||
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
|
|
||||||
|
|
||||||
nsDialogs::Create 1018
|
|
||||||
Pop $0
|
|
||||||
|
|
||||||
${If} $0 == error
|
|
||||||
Abort
|
|
||||||
${EndIf}
|
|
||||||
|
|
||||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
|
|
||||||
Pop $DeleteDataCheckbox
|
|
||||||
${NSD_Uncheck} $DeleteDataCheckbox
|
|
||||||
StrCpy $DeleteDataEnabled "0"
|
|
||||||
|
|
||||||
nsDialogs::Show
|
|
||||||
FunctionEnd
|
|
||||||
|
|
||||||
; Function to handle leaving the data deletion page
|
|
||||||
Function un.DeleteDataPageLeave
|
|
||||||
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
|
|
||||||
FunctionEnd
|
|
||||||
|
|
||||||
Function GetAppFromCommand
|
Function GetAppFromCommand
|
||||||
Exch $1
|
Exch $1
|
||||||
Push $2
|
Push $2
|
||||||
@@ -210,10 +176,10 @@ ${EndIf}
|
|||||||
FunctionEnd
|
FunctionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
Section -MainProgram
|
Section -MainProgram
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
# SetOverwrite ifnewer
|
# SetOverwrite ifnewer
|
||||||
SetOutPath "$INSTDIR"
|
SetOutPath "$INSTDIR"
|
||||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||||
SectionEnd
|
SectionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
@@ -259,58 +225,31 @@ SectionEnd
|
|||||||
Section Uninstall
|
Section Uninstall
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
|
|
||||||
DetailPrint "Stopping Netbird service..."
|
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||||
DetailPrint "Uninstalling Netbird service..."
|
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||||
|
|
||||||
DetailPrint "Terminating Netbird UI process..."
|
# kill ui client
|
||||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||||
|
|
||||||
; Remove autostart registry entry
|
; Remove autostart registry entry
|
||||||
DetailPrint "Removing autostart registry entry if exists..."
|
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
; Handle data deletion based on checkbox
|
|
||||||
DetailPrint "Checking if user requested data deletion..."
|
|
||||||
${If} $DeleteDataEnabled == "1"
|
|
||||||
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
|
|
||||||
ClearErrors
|
|
||||||
RMDir /r "${NETBIRD_DATA_DIR}"
|
|
||||||
IfErrors 0 +2 ; If no errors, jump over the message
|
|
||||||
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
|
|
||||||
DetailPrint "Netbird data directory removal complete."
|
|
||||||
${Else}
|
|
||||||
DetailPrint "User did not opt to delete Netbird data."
|
|
||||||
${EndIf}
|
|
||||||
|
|
||||||
# wait the service uninstall take unblock the executable
|
# wait the service uninstall take unblock the executable
|
||||||
DetailPrint "Waiting for service handle to be released..."
|
|
||||||
Sleep 3000
|
Sleep 3000
|
||||||
|
|
||||||
DetailPrint "Deleting application files..."
|
|
||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
Delete "$INSTDIR\wintun.dll"
|
Delete "$INSTDIR\wintun.dll"
|
||||||
Delete "$INSTDIR\opengl32.dll"
|
Delete "$INSTDIR\opengl32.dll"
|
||||||
DetailPrint "Removing application directory..."
|
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
DetailPrint "Removing shortcuts..."
|
|
||||||
SetShellVarContext all
|
SetShellVarContext all
|
||||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||||
|
|
||||||
DetailPrint "Removing registry keys..."
|
|
||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
|
||||||
|
|
||||||
DetailPrint "Removing application directory from PATH..."
|
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::DeleteValue "path" "$INSTDIR"
|
EnVar::DeleteValue "path" "$INSTDIR"
|
||||||
|
|
||||||
DetailPrint "Uninstallation finished."
|
|
||||||
SectionEnd
|
SectionEnd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -58,11 +58,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
|
|
||||||
if d.firewall == nil {
|
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
total := 0
|
total := 0
|
||||||
@@ -74,8 +69,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
time.Since(start), total)
|
time.Since(start), total)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if d.firewall == nil {
|
||||||
|
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
d.applyPeerACLs(networkMap)
|
d.applyPeerACLs(networkMap)
|
||||||
|
|
||||||
|
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||||
|
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
||||||
|
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||||
|
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||||
|
log.Errorf("failed to set legacy management flag: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
}
|
}
|
||||||
@@ -284,10 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
if d.firewall.IsStateful() {
|
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||||
return "", nil, nil
|
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||||
}
|
|
||||||
// return traffic for outbound connections if firewall is stateless
|
|
||||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
@@ -398,15 +403,11 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
//
|
//
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
if drop {
|
||||||
|
|
||||||
if hasPortRestrictions {
|
|
||||||
// Don't squash rules with port restrictions
|
|
||||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = &protoMatch{
|
protocols[r.Protocol] = &protoMatch{
|
||||||
ips: map[string]int{},
|
ips: map[string]int{},
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -43,31 +42,35 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: network.Addr(),
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
defer func() {
|
t.Errorf("create firewall: %v", err)
|
||||||
err = fw.Close(nil)
|
return
|
||||||
require.NoError(t, err)
|
}
|
||||||
}()
|
defer func(fw manager.Manager) {
|
||||||
|
_ = fw.Close(nil)
|
||||||
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
t.Run("apply firewall rules", func(t *testing.T) {
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if fw.IsStateful() {
|
if len(acl.peerRulesPairs) != 2 {
|
||||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||||
} else {
|
return
|
||||||
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -91,13 +94,12 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
expectedRules := 2
|
// we should have one old and one new rule in the existed rules
|
||||||
if fw.IsStateful() {
|
if len(acl.peerRulesPairs) != 2 {
|
||||||
expectedRules = 1 // only the inbound rule
|
t.Errorf("firewall rules not applied")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
|
||||||
|
|
||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
@@ -105,86 +107,26 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if previousCount != 1 {
|
||||||
expectedPreviousCount := 0
|
t.Errorf("old rule was not removed")
|
||||||
if !fw.IsStateful() {
|
|
||||||
expectedPreviousCount = 1
|
|
||||||
}
|
}
|
||||||
assert.Equal(t, expectedPreviousCount, previousCount)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handle default rules", func(t *testing.T) {
|
t.Run("handle default rules", func(t *testing.T) {
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
networkMap.FirewallRulesIsEmpty = true
|
||||||
acl.ApplyFiltering(networkMap, false)
|
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
|
||||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
if len(acl.peerRulesPairs) != 1 {
|
||||||
expectedRules := 1
|
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||||
if fw.IsStateful() {
|
return
|
||||||
expectedRules = 1 // only inbound allow-all rule
|
|
||||||
}
|
}
|
||||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerStateless(t *testing.T) {
|
|
||||||
// stateless currently only in userspace, so we have to disable kernel
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "80",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
Port: "53",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
|
||||||
IP: network.Addr(),
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
err = fw.Close(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
|
||||||
|
|
||||||
// In stateless mode, we should have both inbound and outbound rules
|
|
||||||
assert.False(t, fw.IsStateful())
|
|
||||||
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,19 +192,42 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
|||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
assert.Equal(t, 2, len(rules))
|
if len(rules) != 2 {
|
||||||
|
t.Errorf("rules should contain 2, got: %v", rules)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
r := rules[0]
|
r := rules[0]
|
||||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
switch {
|
||||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
case r.PeerIP != "0.0.0.0":
|
||||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
return
|
||||||
|
case r.Direction != mgmProto.RuleDirection_IN:
|
||||||
|
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||||
|
return
|
||||||
|
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||||
|
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||||
|
return
|
||||||
|
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||||
|
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
r = rules[1]
|
r = rules[1]
|
||||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
switch {
|
||||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
case r.PeerIP != "0.0.0.0":
|
||||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
return
|
||||||
|
case r.Direction != mgmProto.RuleDirection_OUT:
|
||||||
|
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||||
|
return
|
||||||
|
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||||
|
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||||
|
return
|
||||||
|
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||||
|
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||||
@@ -326,435 +291,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
||||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rules []*mgmProto.FirewallRule
|
|
||||||
expectedCount int
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "should not squash rules with port ranges",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with specific ports",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with legacy port field",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with legacy port field should not be squashed",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with DROP action",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with DROP action should not be squashed",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should squash rules without port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 1,
|
|
||||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed rules should not squash protocol with port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "TCP should not be squashed because one rule has port restrictions",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
// TCP rules with port restrictions - should NOT be squashed
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
// UDP rules without port restrictions - SHOULD be squashed
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
|
||||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: tt.rules,
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
|
||||||
|
|
||||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
|
||||||
if tt.expectedCount == 1 {
|
|
||||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
|
||||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
portInfo *mgmProto.PortInfo
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil PortInfo should be empty",
|
|
||||||
portInfo: nil,
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PortInfo with zero port should be empty",
|
|
||||||
portInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PortInfo with valid port should not be empty",
|
|
||||||
portInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PortInfo with nil range should be empty",
|
|
||||||
portInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PortInfo with zero start range should be empty",
|
|
||||||
portInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 0,
|
|
||||||
End: 100,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PortInfo with zero end range should be empty",
|
|
||||||
portInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 80,
|
|
||||||
End: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PortInfo with valid range should not be empty",
|
|
||||||
portInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := portInfoEmpty(tt.portInfo)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -798,29 +336,33 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: network.Addr(),
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
defer func() {
|
t.Errorf("create firewall: %v", err)
|
||||||
err = fw.Close(nil)
|
return
|
||||||
require.NoError(t, err)
|
}
|
||||||
}()
|
defer func(fw manager.Manager) {
|
||||||
|
_ = fw.Close(nil)
|
||||||
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
expectedRules := 3
|
if len(acl.peerRulesPairs) != 3 {
|
||||||
if fw.IsStateful() {
|
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
return
|
||||||
}
|
}
|
||||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,8 +64,13 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
||||||
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
||||||
|
if runtime.GOOS == "freebsd" {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -101,12 +101,7 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
|
||||||
}
|
|
||||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
|
||||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|||||||
@@ -7,36 +7,15 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
mgm "github.com/netbirdio/netbird/management/client/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPromptLogin(t *testing.T) {
|
func TestPromptLogin(t *testing.T) {
|
||||||
const (
|
|
||||||
promptLogin = "prompt=login"
|
|
||||||
maxAge0 = "max_age=0"
|
|
||||||
)
|
|
||||||
|
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
loginFlag mgm.LoginFlag
|
prompt bool
|
||||||
disablePromptLogin bool
|
|
||||||
expect string
|
|
||||||
}{
|
}{
|
||||||
{
|
{"PromptLogin", true},
|
||||||
name: "Prompt login",
|
{"NoPromptLogin", false},
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
|
||||||
expect: promptLogin,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Max age 0 login",
|
|
||||||
loginFlag: mgm.LoginFlagMaxAge0,
|
|
||||||
expect: maxAge0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Disable prompt login",
|
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
|
||||||
disablePromptLogin: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@@ -49,7 +28,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
LoginFlag: tc.loginFlag,
|
DisablePromptLogin: !tc.prompt,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -59,12 +38,11 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
|
pattern := "prompt=login"
|
||||||
if !tc.disablePromptLogin {
|
if tc.prompt {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
||||||
} else {
|
} else {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user