Compare commits

..

3 Commits

Author SHA1 Message Date
Pedro Costa
43176fb96c Merge branch 'main' into feature/buf-cli
# Conflicts:
#	client/proto/daemon.pb.go
#	management/proto/management.pb.go
2025-05-12 10:34:14 +01:00
Pedro Costa
d1e4bb0fa3 Merge branch 'main' into feature/buf-cli
# Conflicts:
#	management/proto/management.pb.go
2025-04-17 11:53:47 +01:00
Pedro Costa
12b0c13511 [misc] buf cli proto lint, format and gen 2025-04-09 20:39:32 +01:00
312 changed files with 8249 additions and 18046 deletions

View File

@@ -37,21 +37,16 @@ If yes, which one?
**Debug output** **Debug output**
To help us resolve the problem, please attach the following anonymized status output To help us resolve the problem, please attach the following debug output
netbird status -dA netbird status -dA
Create and upload a debug bundle, and share the returned file key: As well as the file created by
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
@@ -62,10 +57,8 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?** **Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions - [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones) - [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client - [ ] Restarted the NetBird client
- [ ] Disabled other VPN software - [ ] Disabled other VPN software
- [ ] Checked firewall settings - [ ] Checked firewall settings

View File

@@ -13,5 +13,3 @@
- [ ] It is a refactor - [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible) - [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary - [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -223,10 +223,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -273,10 +269,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -21,6 +21,7 @@ jobs:
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1
golangci: golangci:
strategy: strategy:
fail-fast: false fail-fast: false

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.20" SIGN_PIPE_VER: "v0.0.18"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -65,13 +65,6 @@ jobs:
with: with:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu

View File

@@ -134,7 +134,6 @@ jobs:
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$' NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
run: | run: |
set -x set -x
@@ -173,15 +172,13 @@ jobs:
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values # check relay values
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$' grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445" grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"' grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true' grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy

View File

@@ -149,7 +149,6 @@ nfpms:
dockers: dockers:
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids: ids:
- netbird - netbird
goarch: amd64 goarch: amd64
@@ -165,7 +164,6 @@ dockers:
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8 - netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids: ids:
- netbird - netbird
goarch: arm64 goarch: arm64
@@ -177,11 +175,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
ids: ids:
- netbird - netbird
goarch: arm goarch: arm
@@ -194,12 +191,11 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64 - netbirdio/netbird:{{ .Version }}-rootless-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
ids: ids:
- netbird - netbird
goarch: amd64 goarch: amd64
@@ -211,11 +207,9 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8 - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids: ids:
- netbird - netbird
goarch: arm64 goarch: arm64
@@ -227,11 +221,9 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm - netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
ids: ids:
- netbird - netbird
goarch: arm goarch: arm
@@ -244,12 +236,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/relay:{{ .Version }}-amd64 - netbirdio/relay:{{ .Version }}-amd64
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
ids: ids:
- netbird-relay - netbird-relay
goarch: amd64 goarch: amd64
@@ -261,11 +251,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8 - netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
ids: ids:
- netbird-relay - netbird-relay
goarch: arm64 goarch: arm64
@@ -277,11 +266,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/relay:{{ .Version }}-arm - netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
ids: ids:
- netbird-relay - netbird-relay
goarch: arm goarch: arm
@@ -294,11 +282,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/signal:{{ .Version }}-amd64 - netbirdio/signal:{{ .Version }}-amd64
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
ids: ids:
- netbird-signal - netbird-signal
goarch: amd64 goarch: amd64
@@ -310,11 +297,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8 - netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
ids: ids:
- netbird-signal - netbird-signal
goarch: arm64 goarch: arm64
@@ -326,11 +312,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/signal:{{ .Version }}-arm - netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
ids: ids:
- netbird-signal - netbird-signal
goarch: arm goarch: arm
@@ -343,11 +328,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/management:{{ .Version }}-amd64 - netbirdio/management:{{ .Version }}-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
ids: ids:
- netbird-mgmt - netbird-mgmt
goarch: amd64 goarch: amd64
@@ -359,11 +343,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/management:{{ .Version }}-arm64v8 - netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
ids: ids:
- netbird-mgmt - netbird-mgmt
goarch: arm64 goarch: arm64
@@ -375,11 +358,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/management:{{ .Version }}-arm - netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-arm
ids: ids:
- netbird-mgmt - netbird-mgmt
goarch: arm goarch: arm
@@ -392,11 +374,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/management:{{ .Version }}-debug-amd64 - netbirdio/management:{{ .Version }}-debug-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
ids: ids:
- netbird-mgmt - netbird-mgmt
goarch: amd64 goarch: amd64
@@ -408,11 +389,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8 - netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
ids: ids:
- netbird-mgmt - netbird-mgmt
goarch: arm64 goarch: arm64
@@ -424,12 +404,11 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/management:{{ .Version }}-debug-arm - netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
ids: ids:
- netbird-mgmt - netbird-mgmt
goarch: arm goarch: arm
@@ -442,11 +421,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/upload:{{ .Version }}-amd64 - netbirdio/upload:{{ .Version }}-amd64
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
ids: ids:
- netbird-upload - netbird-upload
goarch: amd64 goarch: amd64
@@ -458,11 +436,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8 - netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
ids: ids:
- netbird-upload - netbird-upload
goarch: arm64 goarch: arm64
@@ -474,11 +451,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/upload:{{ .Version }}-arm - netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
ids: ids:
- netbird-upload - netbird-upload
goarch: arm goarch: arm
@@ -491,7 +467,7 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
docker_manifests: docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }} - name_template: netbirdio/netbird:{{ .Version }}
@@ -570,84 +546,6 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm64v8 - netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm - netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64 - netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/relay:latest
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:latest
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:debug-latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/upload:latest
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
brews: brews:
- ids: - ids:
- default - default

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://docs.netbird.io/slack-url"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -29,7 +29,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
<br/> <br/>
</strong> </strong>

7
buf.gen.yaml Normal file
View File

@@ -0,0 +1,7 @@
# For details on buf.gen.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-gen-yaml/
version: v2
plugins:
- remote: buf.build/protocolbuffers/go:v1.35.1
out: .
- remote: buf.build/grpc/go:v1.5.1
out: .

10
buf.yaml Normal file
View File

@@ -0,0 +1,10 @@
# For details on buf.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-yaml
version: v2
modules:
- path: proto
lint:
use:
- BASIC
breaking:
use:
- FILE

View File

@@ -1,9 +1,6 @@
FROM alpine:3.21.3 FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,7 +1,6 @@
FROM alpine:3.21.0 FROM alpine:3.21.0
ARG NETBIRD_BINARY=netbird COPY netbird /usr/local/bin/netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \ RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird && adduser -D -h /var/lib/netbird netbird

View File

@@ -59,8 +59,6 @@ type Client struct {
deviceName string deviceName string
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -108,8 +106,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -134,8 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -176,53 +174,6 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
if routes[0].IsDynamic() {
continue
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: routes[0].Network.String(),
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()

View File

@@ -1,27 +0,0 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@@ -7,23 +7,30 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum ConnStatus string // Todo replace to enum
} }
// PeerInfoArray is a wrapper of []PeerInfo // PeerInfoCollection made for Java layer to get non default types as collection
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
type PeerInfoArray struct { type PeerInfoArray struct {
items []PeerInfo items []PeerInfo
} }
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray { func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
array.items = append(array.items, s) array.items = append(array.items, s)
return array return array
} }
// Get return an element of the collection // Get return an element of the collection
func (array *PeerInfoArray) Get(i int) *PeerInfo { func (array PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i] return &array.items[i]
} }
// Size return with the size of the collection // Size return with the size of the collection
func (array *PeerInfoArray) Size() int { func (array PeerInfoArray) Size() int {
return len(array.items) return len(array.items)
} }

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
) )
// Preferences exports a subset of the internal config for gomobile // Preferences export a subset of the internal config for gomobile
type Preferences struct { type Preferences struct {
configInput internal.ConfigInput configInput internal.ConfigInput
} }
// NewPreferences creates a new Preferences instance // NewPreferences create new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci} return &Preferences{ci}
} }
// GetManagementURL reads URL from config file // GetManagementURL read url from config file
func (p *Preferences) GetManagementURL() (string, error) { func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" { if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err return cfg.ManagementURL.String(), err
} }
// SetManagementURL stores the given URL and waits for commit // SetManagementURL store the given url and wait for commit
func (p *Preferences) SetManagementURL(url string) { func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url p.configInput.ManagementURL = url
} }
// GetAdminURL reads URL from config file // GetAdminURL read url from config file
func (p *Preferences) GetAdminURL() (string, error) { func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" { if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err return cfg.AdminURL.String(), err
} }
// SetAdminURL stores the given URL and waits for commit // SetAdminURL store the given url and wait for commit
func (p *Preferences) SetAdminURL(url string) { func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url p.configInput.AdminURL = url
} }
// GetPreSharedKey reads pre-shared key from config file // GetPreSharedKey read preshared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) { func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil { if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil return *p.configInput.PreSharedKey, nil
@@ -66,160 +66,12 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err return cfg.PreSharedKey, err
} }
// SetPreSharedKey stores the given key and waits for commit // SetPreSharedKey store the given key and wait for commit
func (p *Preferences) SetPreSharedKey(key string) { func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// SetRosenpassEnabled stores whether Rosenpass is enabled // Commit write out the changes into config file
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)
return err return err

View File

@@ -69,22 +69,6 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip] return a.ipAnonymizer[ip]
} }
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs // isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 { if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -1,210 +0,0 @@
package cmd
import (
"fmt"
"time"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/management/domain"
"github.com/spf13/cobra"
)
// SharedFlags contains all configuration flags that are common between up and set commands
type SharedFlags struct {
// Network configuration
InterfaceName string
WireguardPort uint16
NATExternalIPs []string
CustomDNSAddress string
ExtraIFaceBlackList []string
DNSLabels []string
DNSRouteInterval time.Duration
// Feature flags
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
AutoConnectDisabled bool
NetworkMonitor bool
LazyConnEnabled bool
// System flags
DisableClientRoutes bool
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
// Login-specific (only for up command)
NoBrowser bool
}
// AddSharedFlags adds all shared configuration flags to the given command
func AddSharedFlags(cmd *cobra.Command, flags *SharedFlags) {
// Network configuration flags
cmd.PersistentFlags().StringVar(&flags.InterfaceName, interfaceNameFlag, iface.WgInterfaceDefault,
"Wireguard interface name")
cmd.PersistentFlags().Uint16Var(&flags.WireguardPort, wireguardPortFlag, iface.DefaultWgPort,
"Wireguard interface listening port")
cmd.PersistentFlags().StringSliceVar(&flags.NATExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces. `+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
`An empty string "" clears the previous configuration. `+
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
`or --external-ip-map ""`)
cmd.PersistentFlags().StringVar(&flags.CustomDNSAddress, dnsResolverAddress, "",
`Sets a custom address for NetBird's local DNS resolver. `+
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
`An empty string "" clears the previous configuration. `+
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`)
cmd.PersistentFlags().StringSliceVar(&flags.ExtraIFaceBlackList, extraIFaceBlackListFlag, nil,
"Extra list of default interfaces to ignore for listening")
cmd.PersistentFlags().StringSliceVar(&flags.DNSLabels, dnsLabelsFlag, nil,
`Sets DNS labels. `+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`)
cmd.PersistentFlags().DurationVar(&flags.DNSRouteInterval, dnsRouteIntervalFlag, time.Minute,
"DNS route update interval")
// Feature flags
cmd.PersistentFlags().BoolVar(&flags.RosenpassEnabled, enableRosenpassFlag, false,
"[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
cmd.PersistentFlags().BoolVar(&flags.RosenpassPermissive, rosenpassPermissiveFlag, false,
"[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
cmd.PersistentFlags().BoolVar(&flags.ServerSSHAllowed, serverSSHAllowedFlag, false,
"Allow SSH server on peer. If enabled, the SSH server will be permitted")
cmd.PersistentFlags().BoolVar(&flags.AutoConnectDisabled, disableAutoConnectFlag, false,
"Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
cmd.PersistentFlags().BoolVarP(&flags.NetworkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`)
cmd.PersistentFlags().BoolVar(&flags.LazyConnEnabled, enableLazyConnectionFlag, false,
"[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
// System flags (from system.go)
cmd.PersistentFlags().BoolVar(&flags.DisableClientRoutes, disableClientRoutesFlag, false,
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
cmd.PersistentFlags().BoolVar(&flags.DisableServerRoutes, disableServerRoutesFlag, false,
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
cmd.PersistentFlags().BoolVar(&flags.DisableDNS, disableDNSFlag, false,
"Disable DNS. If enabled, the client won't configure DNS settings.")
cmd.PersistentFlags().BoolVar(&flags.DisableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
cmd.PersistentFlags().BoolVar(&flags.BlockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
cmd.PersistentFlags().BoolVar(&flags.BlockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
}
// AddUpOnlyFlags adds flags that are specific to the up command
func AddUpOnlyFlags(cmd *cobra.Command, flags *SharedFlags) {
cmd.PersistentFlags().BoolVar(&flags.NoBrowser, noBrowserFlag, false, noBrowserDesc)
}
// BuildConfigInput creates an internal.ConfigInput from SharedFlags with Changed() checks
func BuildConfigInput(cmd *cobra.Command, flags *SharedFlags, customDNSAddressConverted []byte) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
CustomDNSAddress: customDNSAddressConverted,
}
// Handle PreSharedKey from root command
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &flags.RosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &flags.RosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &flags.ServerSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(flags.InterfaceName); err != nil {
return nil, err
}
ic.InterfaceName = &flags.InterfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(flags.WireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &flags.NetworkMonitor
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &flags.AutoConnectDisabled
if flags.AutoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
} else {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &flags.DNSRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &flags.DisableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &flags.DisableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &flags.DisableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &flags.DisableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &flags.BlockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &flags.BlockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &flags.LazyConnEnabled
}
if cmd.Flag(externalIPMapFlag).Changed {
ic.NATExternalIPs = flags.NATExternalIPs
}
if cmd.Flag(extraIFaceBlackListFlag).Changed {
ic.ExtraIFaceBlackList = flags.ExtraIFaceBlackList
}
if cmd.Flag(dnsLabelsFlag).Changed {
var err error
ic.DNSLabels, err = domain.FromStringList(flags.DNSLabels)
if err != nil {
return nil, fmt.Errorf("invalid DNS labels: %v", err)
}
}
return &ic, nil
}

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"runtime"
"strings" "strings"
"time" "time"
@@ -99,11 +98,11 @@ var loginCmd = &cobra.Command{
} }
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq, DnsLabels: dnsLabelsReq,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -196,7 +195,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -244,10 +243,7 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
} }
} }
// isUnixRunningDesktop checks if a Linux OS is running desktop environment // isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool { func isLinuxRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
} }

View File

@@ -26,22 +26,22 @@ import (
) )
const ( const (
externalIPMapFlag = "external-ip-map" externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address" dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass" enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive" rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key" preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name" interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port" wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor" networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
enableLazyConnectionFlag = "enable-lazy-connection" blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle" uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url" uploadBundleURL = "upload-bundle-url"
) )
var ( var (
@@ -77,9 +77,9 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool
debugUploadBundle bool debugUploadBundle bool
debugUploadBundleURL string debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
@@ -149,7 +149,6 @@ func init() {
rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(setCmd)
rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
@@ -168,6 +167,23 @@ func init() {
debugCmd.AddCommand(forCmd) debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd) debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
`An empty string "" clears the previous configuration. `+
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
`or --external-ip-map ""`,
)
upCmd.PersistentFlags().StringVar(&customDNSAddress, dnsResolverAddress, "",
`Sets a custom address for NetBird's local DNS resolver. `+
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
`An empty string "" clears the previous configuration. `+
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))

View File

@@ -2,7 +2,6 @@ package cmd
import ( import (
"context" "context"
"runtime"
"sync" "sync"
"github.com/kardianos/service" "github.com/kardianos/service"
@@ -28,19 +27,12 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
} }
func newSVCConfig() *service.Config { func newSVCConfig() *service.Config {
config := &service.Config{ return &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "Netbird mesh network client", Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string),
} }
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
} }
func newSVC(prg *program, conf *service.Config) (service.Service, error) { func newSVC(prg *program, conf *service.Config) (service.Service, error) {

View File

@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
} }
if logFile != "" { if logFile != "console" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
} }

View File

@@ -1,161 +0,0 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/proto"
)
var (
setFlags = &SharedFlags{}
setCmd = &cobra.Command{
Use: "set",
Short: "Update NetBird client configuration",
Long: `Update NetBird client configuration without connecting. Uses the same flags as 'netbird up' but only updates the configuration file.`,
RunE: setFunc,
}
)
func init() {
// Add all shared flags to the set command
AddSharedFlags(setCmd, setFlags)
// Note: We don't add up-only flags like --no-browser to set command
}
func setFunc(cmd *cobra.Command, _ []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
// Validate inputs (reuse validation logic from up.go)
if err := validateNATExternalIPs(setFlags.NATExternalIPs); err != nil {
return err
}
if cmd.Flag(dnsLabelsFlag).Changed {
if _, err := validateDnsLabels(setFlags.DNSLabels); err != nil {
return err
}
}
var customDNSAddressConverted []byte
if cmd.Flag(dnsResolverAddress).Changed {
var err error
customDNSAddressConverted, err = parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
}
}
// Connect to daemon
ctx := cmd.Context()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if closeErr := conn.Close(); closeErr != nil {
fmt.Printf("Warning: failed to close connection: %v\n", closeErr)
}
}()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SetConfigRequest{}
// Set fields based on changed flags
if cmd.Flag(enableRosenpassFlag).Changed {
req.RosenpassEnabled = &setFlags.RosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
req.RosenpassPermissive = &setFlags.RosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &setFlags.ServerSSHAllowed
}
if cmd.Flag(disableAutoConnectFlag).Changed {
req.DisableAutoConnect = &setFlags.AutoConnectDisabled
}
if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &setFlags.NetworkMonitor
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(setFlags.InterfaceName); err != nil {
return err
}
req.InterfaceName = &setFlags.InterfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
port := int64(setFlags.WireguardPort)
req.WireguardPort = &port
}
if cmd.Flag(dnsResolverAddress).Changed {
customAddr := string(customDNSAddressConverted)
req.CustomDNSAddress = &customAddr
}
if cmd.Flag(extraIFaceBlackListFlag).Changed {
req.ExtraIFaceBlacklist = setFlags.ExtraIFaceBlackList
}
if cmd.Flag(dnsLabelsFlag).Changed {
req.DnsLabels = setFlags.DNSLabels
req.CleanDNSLabels = &[]bool{setFlags.DNSLabels != nil && len(setFlags.DNSLabels) == 0}[0]
}
if cmd.Flag(externalIPMapFlag).Changed {
req.NatExternalIPs = setFlags.NATExternalIPs
req.CleanNATExternalIPs = &[]bool{setFlags.NATExternalIPs != nil && len(setFlags.NATExternalIPs) == 0}[0]
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
req.DnsRouteInterval = durationpb.New(setFlags.DNSRouteInterval)
}
if cmd.Flag(disableClientRoutesFlag).Changed {
req.DisableClientRoutes = &setFlags.DisableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
req.DisableServerRoutes = &setFlags.DisableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
req.DisableDns = &setFlags.DisableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
req.DisableFirewall = &setFlags.DisableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
req.BlockLanAccess = &setFlags.BlockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
req.BlockInbound = &setFlags.BlockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &setFlags.LazyConnEnabled
}
// Send the request
if _, err := client.SetConfig(ctx, req); err != nil {
return fmt.Errorf("update configuration: %w", err)
}
cmd.Println("Configuration updated successfully")
return nil
}

View File

@@ -1,110 +0,0 @@
package cmd
import (
"testing"
"github.com/spf13/cobra"
)
func TestParseBoolArg(t *testing.T) {
tests := []struct {
input string
expected bool
hasError bool
}{
{"true", true, false},
{"True", true, false},
{"1", true, false},
{"yes", true, false},
{"on", true, false},
{"false", false, false},
{"False", false, false},
{"0", false, false},
{"no", false, false},
{"off", false, false},
{"invalid", false, true},
{"maybe", false, true},
{"", false, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
result, err := parseBoolArg(test.input)
if test.hasError {
if err == nil {
t.Errorf("Expected error for input %q, but got none", test.input)
}
} else {
if err != nil {
t.Errorf("Unexpected error for input %q: %v", test.input, err)
}
if result != test.expected {
t.Errorf("For input %q, expected %v but got %v", test.input, test.expected, result)
}
}
})
}
}
func TestSetCommandStructure(t *testing.T) {
// Test that the set command has the expected subcommands
expectedSubcommands := []string{
"autoconnect",
"ssh-server",
"network-monitor",
"rosenpass",
"dns",
"dns-interval",
}
actualSubcommands := make([]string, 0, len(setCmd.Commands()))
for _, cmd := range setCmd.Commands() {
actualSubcommands = append(actualSubcommands, cmd.Name())
}
if len(actualSubcommands) != len(expectedSubcommands) {
t.Errorf("Expected %d subcommands, got %d", len(expectedSubcommands), len(actualSubcommands))
}
for _, expected := range expectedSubcommands {
found := false
for _, actual := range actualSubcommands {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected subcommand %q not found", expected)
}
}
}
func TestSetCommandUsage(t *testing.T) {
if setCmd.Use != "set" {
t.Errorf("Expected command use to be 'set', got %q", setCmd.Use)
}
if setCmd.Short != "Set NetBird client configuration" {
t.Errorf("Expected short description to be 'Set NetBird client configuration', got %q", setCmd.Short)
}
}
func TestSubcommandArgRequirements(t *testing.T) {
// Test that all subcommands except dns-interval require exactly 1 argument
subcommands := []*cobra.Command{
setAutoconnectCmd,
setSSHServerCmd,
setNetworkMonitorCmd,
setRosenpassCmd,
setDNSCmd,
setDNSIntervalCmd,
}
for _, cmd := range subcommands {
if cmd.Args == nil {
t.Errorf("Command %q should have Args validation", cmd.Name())
}
}
}

View File

@@ -44,7 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -69,10 +69,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
status := resp.GetStatus() if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+ cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+ " netbird up \n\n"+
@@ -120,7 +117,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }
@@ -130,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected": case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" { if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
default: default:
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter) return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
} }
if len(ipsFilter) > 0 { if len(ipsFilter) > 0 {

View File

@@ -6,8 +6,6 @@ const (
disableServerRoutesFlag = "disable-server-routes" disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns" disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
) )
var ( var (
@@ -15,7 +13,19 @@ var (
disableServerRoutes bool disableServerRoutes bool
disableDNS bool disableDNS bool
disableFirewall bool disableFirewall bool
blockLANAccess bool
blockInbound bool
) )
func init() {
// Add system flags to upCmd
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
"Disable DNS. If enabled, the client won't configure DNS settings.")
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
}

View File

@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: ` Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0 netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(3),
RunE: tracePacket, RunE: tracePacket,
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
} }
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages { for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil { if stage.ForwardingDetails != nil {

View File

@@ -7,6 +7,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -14,6 +15,7 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
@@ -40,7 +42,6 @@ var (
dnsLabels []string dnsLabels []string
dnsLabelsValidated domain.List dnsLabelsValidated domain.List
noBrowser bool noBrowser bool
upFlags = &SharedFlags{}
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
@@ -50,12 +51,27 @@ var (
) )
func init() { func init() {
// Add shared flags to up command
AddSharedFlags(upCmd, upFlags)
// Add up-specific flags
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
AddUpOnlyFlags(upCmd, upFlags) upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -103,18 +119,79 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
// Handle DNS labels validation and assignment to SharedFlags ic := internal.ConfigInput{
if cmd.Flag(dnsLabelsFlag).Changed { ManagementURL: managementURL,
var err error AdminURL: adminURL,
dnsLabelsValidated, err = validateDnsLabels(upFlags.DNSLabels) ConfigPath: configPath,
if err != nil { NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err return err
} }
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
} }
ic, err := BuildConfigInput(cmd, upFlags, customDNSAddressConverted) if cmd.Flag(dnsRouteIntervalFlag).Changed {
if err != nil { ic.DNSRouteInterval = &dnsRouteInterval
return fmt.Errorf("setup config: %v", err) }
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
} }
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
@@ -122,7 +199,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
config, err := internal.UpdateOrCreateConfig(*ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
} }
@@ -181,67 +258,21 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
if err != nil { if err != nil {
return fmt.Errorf("get setup key: %v", err) return err
} }
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
IsUnixDesktopClient: isUnixRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels, DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0, CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -266,7 +297,7 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return nil, err return err
} }
loginRequest.InterfaceName = &interfaceName loginRequest.InterfaceName = &interfaceName
} }
@@ -301,14 +332,45 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockLanAccess = &blockLANAccess loginRequest.BlockLanAccess = &blockLANAccess
} }
if cmd.Flag(blockInboundFlag).Changed { var loginErr error
loginRequest.BlockInbound = &blockInbound
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
} }
if cmd.Flag(enableLazyConnectionFlag).Changed { if loginErr != nil {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled return fmt.Errorf("login failed: %v", loginErr)
} }
return &loginRequest, nil
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
} }
func validateNATExternalIPs(list []string) error { func validateNATExternalIPs(list []string) error {

View File

@@ -147,10 +147,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -202,7 +198,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil, nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
firewall.ProtocolALL, "all",
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
@@ -223,16 +219,10 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -2,7 +2,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net/netip" "net"
"testing" "testing"
"time" "time"
@@ -19,8 +19,11 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"), Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -67,12 +70,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -92,9 +95,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := netip.MustParseAddr("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -116,8 +119,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"), Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -138,11 +144,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -180,8 +186,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"), Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -203,11 +212,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip := netip.MustParseAddr("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -248,6 +248,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -274,6 +278,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if pair.Masquerade { if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)

View File

@@ -116,8 +116,6 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering( AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,

View File

@@ -170,10 +170,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -328,16 +324,10 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -3,6 +3,7 @@ package nftables
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@@ -24,8 +25,11 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -66,11 +70,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := netip.MustParseAddr("100.96.0.1").Unmap() ip := net.ParseIP("100.96.0.1")
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -105,6 +109,8 @@ func TestNftablesManager(t *testing.T) {
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -126,7 +132,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: ip.AsSlice(), Data: add.AsSlice(),
}, },
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -167,8 +173,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -188,11 +197,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := netip.MustParseAddr("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -273,8 +282,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
}) })
ip := netip.MustParseAddr("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(

View File

@@ -573,6 +573,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -1002,6 +1006,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }

View File

@@ -62,5 +62,5 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
} }

View File

@@ -3,7 +3,6 @@ package conntrack
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@@ -20,10 +19,6 @@ const (
DefaultICMPTimeout = 30 * time.Second DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@@ -34,7 +29,7 @@ type ICMPConnKey struct {
} }
func (i ICMPConnKey) String() string { func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID) return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
@@ -55,72 +50,6 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
@@ -164,64 +93,30 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound( func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound( func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
srcIP netip.Addr, t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track( func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists { if exists {
return return
} }
typ, code := typecode.Type(), typecode.Code() typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@@ -243,7 +138,7 @@ func (t *ICMPTracker) track(
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }

View File

@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
} }
}) })
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
} }
b.ResetTimer() b.ResetTimer()

View File

@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {
// src and remote is swapped // src and remote is swapped
return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
} }

View File

@@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip tcpip.Address ip net.IP
netstack bool netstack bool
} }
@@ -71,11 +71,12 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("failed to create NIC: %v", err)
} }
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: iface.Address().Network.Bits(), PrefixLen: ones,
}, },
} }
@@ -115,7 +116,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), ip: iface.Address().IP,
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@@ -166,7 +167,7 @@ func (f *Forwarder) Stop() {
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) { if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1) return net.IPv4(127, 0, 0, 1)
} }
return addr.AsSlice() return addr.AsSlice()
@@ -178,6 +179,7 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
} }
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil { if errInToOut != nil {
if !isClosedError(errInToOut) { if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in out) for %s: %v", epID(id), errInToOut) f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
} }
} }
if errOutToIn != nil { if errOutToIn != nil {
if !isClosedError(errOutToIn) { if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out in) for %s: %v", epID(id), errOutToIn) f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
} }
} }

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait() wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) { if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outboundinbound) for %s: %v", epID(id), outboundErr) f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
} }
if inboundErr != nil && !isClosedError(inboundErr) { if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inboundoutbound) for %s: %v", epID(id), inboundErr) f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
} }
var rxPackets, txPackets uint64 var rxPackets, txPackets uint64

View File

@@ -45,26 +45,24 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
} }
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if !ip.Is4() { if ipv4 := ip.To4(); ipv4 != nil {
return high := uint16(ipv4[0])
} low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
ipv4 := ip.AsSlice()
high := uint16(ipv4[0]) if bitmap[high] == nil {
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) bitmap[high] = &ipv4LowBitmap{}
}
if bitmap[high] == nil { index := low / 32
bitmap[high] = &ipv4LowBitmap{} bit := low % 32
} bitmap[high].bitmap[index] |= 1 << bit
index := low / 32 ipStr := ipv4.String()
bit := low % 32 if _, exists := ipv4Set[ipStr]; !exists {
bitmap[high].bitmap[index] |= 1 << bit ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
if _, exists := ipv4Set[ip]; !exists { }
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
} }
} }
@@ -81,12 +79,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
} }
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error { func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -104,13 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
addr, ok := netip.AddrFromSlice(ip) if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@@ -124,8 +116,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[netip.Addr]struct{}) ipv4Set := make(map[string]struct{})
var ipv4Addresses []netip.Addr var ipv4Addresses []string
// 127.0.0.0/8 // 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{} newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@@ -20,8 +20,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
@@ -29,8 +32,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
@@ -38,8 +44,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
@@ -47,8 +56,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
@@ -56,8 +68,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
@@ -65,8 +80,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match - addresses 32 apart", name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.33"), testIP: netip.MustParseAddr("192.168.1.33"),
expected: false, expected: false,
@@ -74,8 +92,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("fe80::1"), IP: net.ParseIP("fe80::1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,

View File

@@ -38,8 +38,11 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -39,12 +39,8 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces. // EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
// Default off as it might be security risk because sockets listening on localhost only will become accessible. // Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
@@ -75,6 +71,7 @@ type Manager struct {
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -151,11 +148,6 @@ func parseCreateEnv() (bool, bool) {
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
} }
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
} }
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
@@ -277,7 +269,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil: case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
@@ -334,10 +326,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
@@ -618,8 +606,9 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true return true
} }
// for netflow we keep track even if the firewall is stateless if m.stateful {
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, size)
}
return false return false
} }
@@ -671,7 +660,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
} }
} }
@@ -684,7 +673,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
} }
} }
@@ -788,10 +777,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true return true
} }
// If requested we pass local traffic to internal interfaces to the forwarder. // if running in netstack mode we need to pass this to the forwarder
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder. if m.netstack && m.localForwarding {
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) { return m.handleNetstackLocalTraffic(packetData)
return m.handleForwardedLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
@@ -801,7 +789,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return false return false
} }
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool { func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load() fwd := m.forwarder.Load()
if fwd == nil { if fwd == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
@@ -1099,6 +1088,11 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true return true
} }
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not

View File

@@ -174,6 +174,11 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup // Apply scenario-specific setup
sc.setupFunc(manager) sc.setupFunc(manager)
@@ -214,6 +219,11 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table // Pre-populate connection table
srcIPs := generateRandomIPs(count) srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count)
@@ -257,6 +267,11 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0] srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@@ -289,6 +304,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -302,6 +321,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -316,6 +339,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -329,6 +356,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -342,6 +373,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -355,6 +390,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -369,6 +408,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "post_handshake", state: "post_handshake",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -383,6 +426,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -396,6 +443,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -542,6 +593,11 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -625,6 +681,11 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -736,6 +797,11 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -816,6 +882,11 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err) require.NoError(b, err)
@@ -961,8 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
dst := fw.Network{Prefix: r.dest} _, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -19,8 +19,12 @@ import (
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100") localIP := net.ParseIP("100.10.0.100")
wgNet := netip.MustParsePrefix("100.10.0.0/16") wgNet := &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
@@ -39,6 +43,8 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -575,13 +581,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix(network) localIP, wgNet, err := net.ParseCIDR(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: wgNet.Addr(), IP: localIP,
Network: wgNet, Network: wgNet,
} }
}, },
@@ -1433,8 +1440,11 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -271,8 +271,11 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -282,6 +285,10 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
@@ -389,6 +396,10 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
@@ -498,6 +509,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{

View File

@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a) { if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }

View File

@@ -1,17 +0,0 @@
package configurer
import (
"net"
"net/netip"
)
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -5,7 +5,6 @@ package configurer
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -13,8 +12,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var zeroKey wgtypes.Key
type KernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
@@ -46,7 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil return nil
} }
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -55,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps), AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
@@ -92,10 +89,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
return nil return nil
} }
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
ipNet := net.IPNet{ _, ipNet, err := net.ParseCIDR(allowedIP)
IP: allowedIP.Addr().AsSlice(), if err != nil {
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()), return err
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -106,7 +103,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet}, AllowedIPs: []net.IPNet{*ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@@ -119,10 +116,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
return nil return nil
} }
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
ipNet := net.IPNet{ _, ipNet, err := net.ParseCIDR(allowedIP)
IP: allowedIP.Addr().AsSlice(), if err != nil {
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()), return fmt.Errorf("parse allowed IP: %w", err)
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -190,11 +187,7 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
if err != nil { if err != nil {
return err return err
} }
defer func() { defer wg.Close()
if err := wg.Close(); err != nil {
log.Errorf("Failed to close wgctrl client: %v", err)
}
}()
// validate if device with name exists // validate if device with name exists
_, err = wg.Device(c.deviceName) _, err = wg.Device(c.deviceName)
@@ -208,71 +201,14 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() { func (c *KernelConfigurer) Close() {
} }
func (c *KernelConfigurer) FullStats() (*Stats, error) { func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
wg, err := wgctrl.New() peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("wgctl: %w", err) return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
} }
defer func() { return WGStats{
err = wg.Close() LastHandshake: peer.LastHandshakeTime,
if err != nil { TxBytes: peer.TransmitBytes,
log.Errorf("Got error while closing wgctl: %v", err) RxBytes: peer.ReceiveBytes,
} }, nil
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}
}
return stats, nil
} }

View File

@@ -1,11 +1,9 @@
package configurer package configurer
import ( import (
"encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
@@ -19,20 +17,6 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct { type WGUSPConfigurer struct {
@@ -68,7 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -77,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps), AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,
@@ -107,10 +91,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
ipNet := net.IPNet{ _, ipNet, err := net.ParseCIDR(allowedIP)
IP: allowedIP.Addr().AsSlice(), if err != nil {
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()), return err
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -121,7 +105,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet}, AllowedIPs: []net.IPNet{*ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@@ -131,7 +115,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet() ipc, err := c.device.IpcGet()
if err != nil { if err != nil {
return err return err
@@ -154,8 +138,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
foundPeer := false foundPeer := false
removedAllowedIP := false removedAllowedIP := false
ip := allowedIP.String()
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@@ -178,8 +160,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
// Append the line to the output string // Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") { if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=") allowedIP := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIPStr) _, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil { if err != nil {
return err return err
} }
@@ -196,15 +178,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error
@@ -244,75 +217,91 @@ func (t *WGUSPConfigurer) Close() {
} }
} }
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) { func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet() ipc, err := t.device.IpcGet()
if err != nil { if err != nil {
return nil, fmt.Errorf("ipc get: %w", err) return WGStats{}, fmt.Errorf("ipc get: %w", err)
} }
return parseTransfers(ipc) stats, err := findPeerInfo(ipc, peerKey, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
} }
func parseTransfers(ipc string) (map[string]WGStats, error) { func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
stats := make(map[string]WGStats) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
var ( if err != nil {
currentKey string return nil, fmt.Errorf("parse key: %w", err)
currentStats WGStats }
hasPeer bool
) hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipc, "\n")
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key, // If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop. // this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") { if strings.HasPrefix(line, "public_key=") && foundPeer {
peerID := strings.TrimPrefix(line, "public_key=") break
h, err := hex.DecodeString(peerID)
if err != nil {
return nil, fmt.Errorf("decode peerID: %w", err)
}
currentKey = base64.StdEncoding.EncodeToString(h)
currentStats = WGStats{} // Reset stats for the new peer
hasPeer = true
stats[currentKey] = currentStats
continue
} }
if !hasPeer { // Identify the peer with the specific public key
continue if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
} }
key := strings.SplitN(line, "=", 2) for _, key := range searchConfigKeys {
if len(key) != 2 { if foundPeer && strings.HasPrefix(line, key+"=") {
continue v := strings.SplitN(line, "=", 2)
} configFound[v[0]] = v[1]
switch key[0] {
case ipcKeyLastHandshakeTimeSec:
hs, err := toLastHandshake(key[1])
if err != nil {
return nil, err
} }
currentStats.LastHandshake = hs
stats[currentKey] = currentStats
case ipcKeyRxBytes:
rxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
currentStats.RxBytes = rxBytes
stats[currentKey] = currentStats
case ipcKeyTxBytes:
TxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
currentStats.TxBytes = TxBytes
stats[currentKey] = currentStats
} }
} }
return stats, nil // todo: use multierr
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
} }
func toWgUserspaceString(wgCfg wgtypes.Config) string { func toWgUserspaceString(wgCfg wgtypes.Config) string {
@@ -366,154 +355,9 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String() return sb.String()
} }
func toLastHandshake(stringVar string) (time.Time, error) {
sec, err := strconv.ParseInt(stringVar, 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
return time.Unix(sec, 0), nil
}
func toBytes(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark return nbnet.ControlPlaneMark
} }
return 0 return 0
} }
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -2,8 +2,10 @@ package configurer
import ( import (
"encoding/hex" "encoding/hex"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -32,35 +34,58 @@ errno=0
` `
func Test_parseTransfers(t *testing.T) { func Test_findPeerInfo(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
peerKey string peerKey string
want WGStats searchKeys []string
want map[string]string
wantErr bool
}{ }{
{ {
name: "single", name: "single",
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33", peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
want: WGStats{ searchKeys: []string{"tx_bytes"},
TxBytes: 0, want: map[string]string{
RxBytes: 0, "tx_bytes": "38333",
}, },
wantErr: false,
}, },
{ {
name: "multiple", name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
want: WGStats{ searchKeys: []string{"tx_bytes", "rx_bytes"},
TxBytes: 38333, want: map[string]string{
RxBytes: 2224, "tx_bytes": "38333",
"rx_bytes": "2224",
}, },
wantErr: false,
}, },
{ {
name: "lastpeer", name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58", peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
want: WGStats{ searchKeys: []string{"tx_bytes", "rx_bytes"},
TxBytes: 1212111, want: map[string]string{
RxBytes: 1929999999, "tx_bytes": "1212111",
"rx_bytes": "1929999999",
}, },
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -71,19 +96,9 @@ func Test_parseTransfers(t *testing.T) {
key, err := wgtypes.NewKey(res) key, err := wgtypes.NewKey(res)
require.NoError(t, err) require.NoError(t, err)
stats, err := parseTransfers(ipcFixture) got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
if err != nil { assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
require.NoError(t, err) assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
return
}
stat, ok := stats[key.String()]
if !ok {
require.True(t, ok)
return
}
require.Equal(t, tt.want, stat)
}) })
} }
} }

View File

@@ -1,24 +0,0 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -24,7 +24,6 @@ type WGTunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool
name string name string
device *device.Device device *device.Device
@@ -33,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -41,7 +40,6 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS,
} }
} }
@@ -51,13 +49,6 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)

View File

@@ -1,6 +1,7 @@
package device package device
import ( import (
"net"
"net/netip" "net/netip"
"sync" "sync"
@@ -23,6 +24,9 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
} }
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets

View File

@@ -51,11 +51,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1) dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@@ -2,7 +2,6 @@ package device
import ( import (
"net" "net"
"net/netip"
"time" "time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -12,11 +11,10 @@ import (
type WGConfigurer interface { type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP string) error
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
} }

View File

@@ -64,15 +64,7 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
ip := address.IP.String() ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -43,7 +43,6 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
DisableDNS bool
} }
// WGIface represents an interface instance // WGIface represents an interface instance
@@ -112,14 +111,14 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
} }
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional. // Endpoint is optional
// If allowedIps is given it will be added to the existing ones.
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps) netIPNets := prefixesToIPNets(allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
@@ -132,7 +131,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
} }
// AddAllowedIP adds a prefix to the allowed IPs list of peer // AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -141,7 +140,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
} }
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer // RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -186,6 +185,7 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil
@@ -212,13 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device() return w.tun.Device()
} }
// GetStats returns the last handshake time, rx and tx bytes // GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) { func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats() return w.configurer.GetStats(peerKey)
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
} }
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {
@@ -255,3 +251,14 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet() return w.tun.GetNet()
} }
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -5,6 +5,7 @@
package mocks package mocks
import ( import (
net "net"
"net/netip" "net/netip"
reflect "reflect" reflect "reflect"
@@ -89,3 +90,15 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
} }
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,6 +1,8 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"strconv" "strconv"
@@ -13,8 +15,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address netip.Addr address net.IP
dnsAddress netip.Addr dnsAddress net.IP
mtu int mtu int
listenAddress string listenAddress string
@@ -22,7 +24,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
@@ -32,9 +34,19 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A
} }
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{t.address}, []netip.Addr{addr.Unmap()},
[]netip.Addr{t.dnsAddress}, []netip.Addr{dnsAddr.Unmap()},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -2,27 +2,28 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net/netip" "net"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP netip.Addr IP net.IP
Network netip.Prefix Network *net.IPNet
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (Address, error) {
prefix, err := netip.ParsePrefix(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return Address{}, err return Address{}, err
} }
return Address{ return Address{
IP: prefix.Addr().Unmap(), IP: ip,
Network: prefix.Masked(), Network: network,
}, nil }, nil
} }
func (addr Address) String() string { func (addr Address) String() string {
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits()) maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -24,8 +24,6 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run" !define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -51,10 +49,6 @@ ShowInstDetails Show
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}" !define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@@ -64,6 +58,9 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
!define MUI_UNABORTWARNING !define MUI_UNABORTWARNING
@@ -73,16 +70,13 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
!insertmacro MUI_UNPAGE_WELCOME
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM !insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES !insertmacro MUI_UNPAGE_INSTFILES
@@ -95,10 +89,6 @@ UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
Var AutostartCheckbox Var AutostartCheckbox
Var AutostartEnabled Var AutostartEnabled
; Variables for uninstall data deletion option
Var DeleteDataCheckbox
Var DeleteDataEnabled
###################################################################### ######################################################################
; Function to create the autostart options page ; Function to create the autostart options page
@@ -114,8 +104,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts" ${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show nsDialogs::Show
FunctionEnd FunctionEnd
@@ -125,30 +115,6 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled ${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd FunctionEnd
; Function to create the uninstall data deletion page
Function un.DeleteDataPage
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
Pop $DeleteDataCheckbox
${NSD_Uncheck} $DeleteDataCheckbox
StrCpy $DeleteDataEnabled "0"
nsDialogs::Show
FunctionEnd
; Function to handle leaving the data deletion page
Function un.DeleteDataPageLeave
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -210,10 +176,10 @@ ${EndIf}
FunctionEnd FunctionEnd
###################################################################### ######################################################################
Section -MainProgram Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
# SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -259,58 +225,31 @@ SectionEnd
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..." # kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry ; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."
${If} $DeleteDataEnabled == "1"
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
ClearErrors
RMDir /r "${NETBIRD_DATA_DIR}"
IfErrors 0 +2 ; If no errors, jump over the message
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
DetailPrint "Netbird data directory removal complete."
${Else}
DetailPrint "User did not opt to delete Netbird data."
${EndIf}
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
DetailPrint "Waiting for service handle to be released..."
Sleep 3000 Sleep 3000
DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
DetailPrint "Removing shortcuts..."
SetShellVarContext all SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk" Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk" Delete "$SMPROGRAMS\${APP_NAME}.lnk"
DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR" EnVar::DeleteValue "path" "$INSTDIR"
DetailPrint "Uninstallation finished."
SectionEnd SectionEnd

View File

@@ -58,11 +58,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
@@ -74,8 +69,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total) time.Since(start), total)
}() }()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@@ -284,10 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() { // TODO: Remove this soon. Outbound rules are obsolete.
return "", nil, nil // We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -398,15 +403,11 @@ func (d *DefaultManager) squashAcceptRules(
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
r.Port != "" || !portInfoEmpty(r.PortInfo) if drop {
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return return
} }
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{ protocols[r.Protocol] = &protoMatch{
ips: map[string]int{}, ips: map[string]int{},

View File

@@ -1,14 +1,13 @@
package acl package acl
import ( import (
"net/netip" "net"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -43,31 +42,35 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32") ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(), IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err) if err != nil {
defer func() { t.Errorf("create firewall: %v", err)
err = fw.Close(nil) return
require.NoError(t, err) }
}() defer func(fw manager.Manager) {
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if fw.IsStateful() { if len(acl.peerRulesPairs) != 2 {
assert.Equal(t, 0, len(acl.peerRulesPairs)) t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
} else { return
assert.Equal(t, 2, len(acl.peerRulesPairs))
} }
}) })
@@ -91,13 +94,12 @@ func TestDefaultManager(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
expectedRules := 2 // we should have one old and one new rule in the existed rules
if fw.IsStateful() { if len(acl.peerRulesPairs) != 2 {
expectedRules = 1 // only the inbound rule t.Errorf("firewall rules not applied")
return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
@@ -105,86 +107,26 @@ func TestDefaultManager(t *testing.T) {
previousCount++ previousCount++
} }
} }
if previousCount != 1 {
expectedPreviousCount := 0 t.Errorf("old rule was not removed")
if !fw.IsStateful() {
expectedPreviousCount = 1
} }
assert.Equal(t, expectedPreviousCount, previousCount)
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
acl.ApplyFiltering(networkMap, false) if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
assert.Equal(t, 0, len(acl.peerRulesPairs)) t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 {
expectedRules := 1 t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
if fw.IsStateful() { return
expectedRules = 1 // only inbound allow-all rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
})
}
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
// In stateless mode, we should have both inbound and outbound rules
assert.False(t, fw.IsStateful())
assert.Equal(t, 2, len(acl.peerRulesPairs))
}) })
} }
@@ -250,19 +192,42 @@ func TestDefaultManagerSquashRules(t *testing.T) {
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules)) if len(rules) != 2 {
t.Errorf("rules should contain 2, got: %v", rules)
return
}
r := rules[0] r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP) switch {
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) case r.PeerIP != "0.0.0.0":
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) return
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1] r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP) switch {
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) case r.PeerIP != "0.0.0.0":
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) return
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
} }
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
@@ -326,435 +291,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
} }
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules)) t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
} }
} }
@@ -798,29 +336,33 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32") ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(), IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err) if err != nil {
defer func() { t.Errorf("create firewall: %v", err)
err = fw.Close(nil) return
require.NoError(t, err) }
}() defer func(fw manager.Manager) {
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
expectedRules := 3 if len(acl.peerRulesPairs) != 3 {
if fw.IsStateful() { t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
expectedRules = 3 // 2 inbound rules + SSH rule return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
} }

View File

@@ -64,8 +64,13 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { if runtime.GOOS == "linux" && !isLinuxDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }

View File

@@ -101,12 +101,7 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
} }
if !p.providerConfig.DisablePromptLogin { if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() { params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
} }
authURL := p.oAuthConfig.AuthCodeURL(state, params...) authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -7,36 +7,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
) )
func TestPromptLogin(t *testing.T) { func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct { tt := []struct {
name string name string
loginFlag mgm.LoginFlag prompt bool
disablePromptLogin bool
expect string
}{ }{
{ {"PromptLogin", true},
name: "Prompt login", {"NoPromptLogin", false},
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
} }
for _, tc := range tt { for _, tc := range tt {
@@ -49,7 +28,7 @@ func TestPromptLogin(t *testing.T) {
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize", AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"}, RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true, UseIDToken: true,
LoginFlag: tc.loginFlag, DisablePromptLogin: !tc.prompt,
} }
pkce, err := NewPKCEAuthorizationFlow(config) pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil { if err != nil {
@@ -59,12 +38,11 @@ func TestPromptLogin(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to request auth info: %v", err) t.Fatalf("Failed to request auth info: %v", err)
} }
pattern := "prompt=login"
if !tc.disablePromptLogin { if tc.prompt {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect) require.Contains(t, authInfo.VerificationURIComplete, pattern)
} else { } else {
require.Contains(t, authInfo.VerificationURIComplete, promptLogin) require.NotContains(t, authInfo.VerificationURIComplete, pattern)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
} }
}) })
} }

View File

@@ -68,14 +68,12 @@ type ConfigInput struct {
DisableServerRoutes *bool DisableServerRoutes *bool
DisableDNS *bool DisableDNS *bool
DisableFirewall *bool DisableFirewall *bool
BlockLANAccess *bool
BlockInbound *bool BlockLANAccess *bool
DisableNotifications *bool DisableNotifications *bool
DNSLabels domain.List DNSLabels domain.List
LazyConnectionEnabled *bool
} }
// Config Configuration type // Config Configuration type
@@ -98,8 +96,8 @@ type Config struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool
BlockInbound bool BlockLANAccess bool
DisableNotifications *bool DisableNotifications *bool
@@ -140,8 +138,6 @@ type Config struct {
ClientCertKeyPath string ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"` ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -223,8 +219,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ config := &Config{
// defaults to false only for new (post 0.26) configurations // defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(), ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
} }
if _, err := config.apply(input); err != nil { if _, err := config.apply(input); err != nil {
@@ -319,6 +313,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
*input.WireguardPort, config.WgPort) *input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort config.WgPort = *input.WireguardPort
updated = true updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
} }
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface { if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
@@ -414,15 +412,9 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true updated = true
} else if config.ServerSSHAllowed == nil { } else if config.ServerSSHAllowed == nil {
if runtime.GOOS == "android" { // enables SSH for configs from old versions to preserve backwards compatibility
// default to disabled SSH on Android for security log.Infof("falling back to enabled SSH server for pre-existing configuration")
log.Infof("setting SSH server to false by default on Android") config.ServerSSHAllowed = util.True()
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
updated = true updated = true
} }
@@ -487,16 +479,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
if *input.BlockInbound {
log.Infof("blocking inbound connections")
} else {
log.Infof("allowing inbound connections")
}
config.BlockInbound = *input.BlockInbound
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications { if *input.DisableNotifications {
log.Infof("disabling notifications") log.Infof("disabling notifications")
@@ -542,12 +524,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
updated = true
}
return updated, nil return updated, nil
} }

View File

@@ -1,312 +0,0 @@
package internal
import (
"context"
"os"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
// - Managing lazy connections via the lazyConnManager
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
wg sync.WaitGroup
lazyCtx context.Context
lazyCtxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
return nil
}
if enabled {
// if the lazy connection manager is already started, do not start it again
if e.lazyConnMgr != nil {
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager()
} else {
if e.lazyConnMgr == nil {
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
e.closeManager(ctx)
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
}
// UpdateRouteHAMap updates the route HA mappings in the lazy connection manager
func (e *ConnMgr) UpdateRouteHAMap(haMap route.HAMap) {
if !e.isStartedWithLazyMgr() {
log.Debugf("lazy connection manager is not started, skipping UpdateRouteHAMap")
return
}
e.lazyConnMgr.UpdateRouteHAMap(haMap)
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
if e.lazyConnMgr == nil {
return
}
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for peerID := range peerIDs {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
continue
}
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err)
}
}
}
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
return true
}
if !e.isStartedWithLazyMgr() {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if !lazyconn.IsSupported(conn.AgentVersionString()) {
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerKey,
AllowedIPs: conn.WgConfig().AllowedIps,
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if excluded {
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
conn.Log.Infof("peer added to lazy conn manager")
return
}
func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn, ok := e.peerStore.Remove(peerKey)
if !ok {
return
}
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
}
e.lazyConnMgr.RemovePeer(peerKey)
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
func (e *ConnMgr) Close() {
if !e.isStartedWithLazyMgr() {
return
}
e.lazyCtxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
}
func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
e.wg.Add(1)
go func() {
defer e.wg.Done()
e.lazyConnMgr.Start(e.lazyCtx)
}()
}
func (e *ConnMgr) addPeersToLazyConnManager() error {
peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {
if e.lazyConnMgr == nil {
return
}
e.lazyCtxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
for _, peerID := range e.peerStore.PeersPubKey() {
e.peerStore.PeerConnOpen(ctx, peerID)
}
}
func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {
return nil
}
parsedMinutes, err := strconv.Atoi(envValue)
if err != nil || parsedMinutes <= 0 {
return nil
}
d := time.Duration(parsedMinutes) * time.Minute
return &d
}

View File

@@ -17,6 +17,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -435,13 +436,11 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DNSRouteInterval: config.DNSRouteInterval, DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes, DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound, DisableServerRoutes: config.DisableServerRoutes,
DisableDNS: config.DisableDNS, DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall, DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled, BlockLANAccess: config.BlockLANAccess,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -482,7 +481,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
return signalClient, nil return signalClient, nil
} }
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) // loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) { func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey() serverPublicKey, err := client.GetServerPublicKey()
@@ -499,9 +498,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.DisableServerRoutes, config.DisableServerRoutes,
config.DisableDNS, config.DisableDNS,
config.DisableFirewall, config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
) )
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil { if err != nil {
@@ -525,13 +521,17 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port. // freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
func freePort(initPort int) (int, error) { func freePort(initPort int) (int, error) {
addr := net.UDPAddr{Port: initPort} addr := net.UDPAddr{}
if initPort == 0 {
initPort = iface.DefaultWgPort
}
addr.Port = initPort
conn, err := net.ListenUDP("udp", &addr) conn, err := net.ListenUDP("udp", &addr)
if err == nil { if err == nil {
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
closeConnWithLog(conn) closeConnWithLog(conn)
return returnPort, nil return initPort, nil
} }
// if the port is already in use, ask the system for a free port // if the port is already in use, ask the system for a free port

View File

@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
shouldMatch bool shouldMatch bool
}{ }{
{ {
name: "when port is 0 use random port", name: "not provided, fallback to default",
port: 0, port: 0,
want: 0, want: 51820,
shouldMatch: false, shouldMatch: true,
}, },
{ {
name: "provided and available", name: "provided and available",
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
shouldMatch: false, shouldMatch: false,
}, },
} }
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0}) c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
if err != nil { if err != nil {
t.Errorf("freePort error = %v", err) t.Errorf("freePort error = %v", err)
} }
@@ -39,14 +39,6 @@ func Test_freePort(t *testing.T) {
_ = c1.Close() _ = c1.Close()
}(c1) }(c1)
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
tests[1].port++
tests[1].want++
}
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View File

@@ -4,7 +4,6 @@ import (
"archive/zip" "archive/zip"
"bufio" "bufio"
"bytes" "bytes"
"compress/gzip"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -270,21 +269,11 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err) log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
} }
if err := g.addWgShow(); err != nil { if g.logFile != "console" {
log.Errorf("Failed to add wg show output: %v", err)
}
if g.logFile != "console" && g.logFile != "" {
if err := g.addLogfile(); err != nil { if err := g.addLogfile(); err != nil {
log.Errorf("Failed to add log file to debug bundle: %v", err) return fmt.Errorf("add log file: %w", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs as fallback: %v", err)
}
} }
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs: %v", err)
} }
return nil return nil
} }
@@ -376,34 +365,17 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled)) configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive)) configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
if g.internalConfig.ServerSSHAllowed != nil { if g.internalConfig.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed)) configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
} }
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
if g.internalConfig.DisableNotifications != nil {
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
}
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect)) configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval)) configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
if g.internalConfig.ClientCertPath != "" { configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath)) configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes))
} configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
if g.internalConfig.ClientCertKeyPath != "" { configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {
@@ -561,33 +533,6 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err) return fmt.Errorf("add client log file to zip: %w", err)
} }
// add latest rotated log file
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
stdErrLogPath := filepath.Join(logDir, errorLogFile) stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile) stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
@@ -618,13 +563,16 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
} }
}() }()
var logReader io.Reader = logFile var logReader io.Reader
if g.anonymize { if g.anonymize {
var writer *io.PipeWriter var writer *io.PipeWriter
logReader, writer = io.Pipe() logReader, writer = io.Pipe()
go anonymizeLog(logFile, writer, g.anonymizer) go anonymizeLog(logFile, writer, g.anonymizer)
} else {
logReader = logFile
} }
if err := g.addFileToZip(logReader, targetName); err != nil { if err := g.addFileToZip(logReader, targetName); err != nil {
return fmt.Errorf("add %s to zip: %w", targetName, err) return fmt.Errorf("add %s to zip: %w", targetName, err)
} }
@@ -632,44 +580,6 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
return nil return nil
} }
// addSingleLogFileGz adds a single gzipped log file to the archive
func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
f, err := os.Open(logPath)
if err != nil {
return fmt.Errorf("open gz log file %s: %w", targetName, err)
}
defer f.Close()
gzr, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("create gzip reader: %w", err)
}
defer gzr.Close()
var logReader io.Reader = gzr
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(gzr, pw, g.anonymizer)
}
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
if _, err := io.Copy(gw, logReader); err != nil {
return fmt.Errorf("re-gzip: %w", err)
}
if err := gw.Close(); err != nil {
return fmt.Errorf("close gzip writer: %w", err)
}
if err := g.addFileToZip(&buf, targetName); err != nil {
return fmt.Errorf("add anonymized gz: %w", err)
}
return nil
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error { func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{ header := &zip.FileHeader{
Name: filename, Name: filename,

View File

@@ -4,104 +4,17 @@ package debug
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sort" "sort"
"strings" "strings"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
maxLogEntries = 100000
maxLogAge = 7 * 24 * time.Hour // Last 7 days
)
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
func (g *BundleGenerator) trySystemdLogFallback() error {
log.Debug("Attempting to collect systemd journal logs")
serviceName := getServiceName()
journalLogs, err := getSystemdLogs(serviceName)
if err != nil {
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
}
if strings.Contains(journalLogs, "No recent log entries found") {
log.Debug("No recent log entries found in systemd journal")
return nil
}
if g.anonymize {
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
}
logReader := strings.NewReader(journalLogs)
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
if err := g.addFileToZip(logReader, fileName); err != nil {
return fmt.Errorf("add systemd logs to bundle: %w", err)
}
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
return nil
}
// getServiceName gets the service name from environment or defaults to netbird
func getServiceName() string {
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
return unitName
}
return "netbird"
}
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
func getSystemdLogs(serviceName string) (string, error) {
args := []string{
"-u", fmt.Sprintf("%s.service", serviceName),
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
"--lines", fmt.Sprintf("%d", maxLogEntries),
"--no-pager",
"--output", "short-iso",
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "journalctl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
}
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("journalctl command not found: %w", err)
}
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
}
logs := stdout.String()
if strings.TrimSpace(logs) == "" {
return "No recent log entries found in systemd journal", nil
}
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
serviceName, maxLogEntries, maxLogAge.String())
return header + logs, nil
}
// addFirewallRules collects and adds firewall rules to the archive // addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules") log.Info("Collecting firewall rules")
@@ -568,7 +481,7 @@ func formatExpr(exp expr.Any) string {
case *expr.Fib: case *expr.Fib:
return formatFib(e) return formatFib(e)
case *expr.Target: case *expr.Target:
return fmt.Sprintf("jump %s", e.Name) return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets
case *expr.Immediate: case *expr.Immediate:
if e.Register == 1 { if e.Register == 1 {
return formatImmediateData(e.Data) return formatImmediateData(e.Data)

View File

@@ -6,9 +6,3 @@ package debug
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
return nil return nil
} }
func (g *BundleGenerator) trySystemdLogFallback() error {
// Systemd is only available on Linux
// TODO: Add BSD support
return nil
}

View File

@@ -1,66 +0,0 @@
package debug
import (
"bytes"
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGIface interface {
FullStats() (*configurer.Stats, error)
}
func (g *BundleGenerator) addWgShow() error {
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err
}
output := g.toWGShowFormat(result)
reader := bytes.NewReader([]byte(output))
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
return fmt.Errorf("add wg show to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
if s.FWMark != 0 {
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
}
for _, peer := range s.Peers {
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
if peer.Endpoint.IP != nil {
if g.anonymize {
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
} else {
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
}
}
if len(peer.AllowedIPs) > 0 {
var ipStrings []string
for _, ipnet := range peer.AllowedIPs {
ipStrings = append(ipStrings, ipnet.String())
}
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
sb.WriteString(" preshared key: (hidden)\n")
}
}
return sb.String()
}

View File

@@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"net/netip" "net"
"slices" "slices"
"strings" "strings"
@@ -12,14 +12,13 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
ip, err := netip.ParseAddr(aRecord.RData) ip := net.ParseIP(aRecord.RData)
if err != nil { if ip == nil || ip.To4() == nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
if !prefix.Contains(ip) { if !ipNet.Contains(ip) {
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
@@ -37,19 +36,16 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.Sim
} }
// generateReverseZoneName creates the reverse DNS zone name for a given network // generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(network netip.Prefix) (string, error) { func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
networkIP := network.Masked().Addr() networkIP := ipNet.IP.Mask(ipNet.Mask)
maskOnes, _ := ipNet.Mask.Size()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// round up to nearest byte // round up to nearest byte
octetsToUse := (network.Bits() + 7) / 8 octetsToUse := (maskOnes + 7) / 8
octets := strings.Split(networkIP.String(), ".") octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) { if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits()) return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
} }
reverseOctets := make([]string, octetsToUse) reverseOctets := make([]string, octetsToUse)
@@ -72,7 +68,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
} }
// collectPTRRecords gathers all PTR records for the given network from A records // collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord { func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones { for _, zone := range config.CustomZones {
@@ -81,7 +77,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
continue continue
} }
if ptrRecord, ok := createPTRRecord(record, prefix); ok { if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
records = append(records, ptrRecord) records = append(records, ptrRecord)
} }
} }
@@ -91,8 +87,8 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
} }
// addReverseZone adds a reverse DNS zone to the configuration for the given network // addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, network netip.Prefix) { func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
zoneName, err := generateReverseZoneName(network) zoneName, err := generateReverseZoneName(ipNet)
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
return return
@@ -103,7 +99,7 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
return return
} }
records := collectPTRRecords(config, network) records := collectPTRRecords(config, ipNet)
reverseZone := nbdns.CustomZone{ reverseZone := nbdns.CustomZone{
Domain: zoneName, Domain: zoneName,

View File

@@ -1,7 +1,6 @@
package dns package dns
import ( import (
"fmt"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -11,10 +10,9 @@ import (
) )
const ( const (
PriorityLocal = 100 PriorityDNSRoute = 100
PriorityDNSRoute = 75 PriorityMatchDomain = 50
PriorityUpstream = 50 PriorityDefault = 1
PriorityDefault = 1
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {
@@ -150,42 +148,61 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
qname := strings.ToLower(r.Question[0].Name) qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
handlers := slices.Clone(c.handlers) handlers := slices.Clone(c.handlers)
c.mu.RUnlock() c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder log.Tracef("current handlers (%d):", len(handlers))
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers { for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n", log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
} }
log.Trace(strings.TrimSuffix(b.String(), "\n"))
} }
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
matched := c.isHandlerMatch(qname, entry) var matched bool
switch {
if matched { case entry.Pattern == ".":
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d", matched = true
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority) case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
chainWriter := &ResponseWriterChain{ matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
ResponseWriter: w, default:
origPattern: entry.OrigPattern, // For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = strings.EqualFold(qname, entry.Pattern)
} }
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
continue
}
return
} }
if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue
}
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler")
continue
}
return
} }
// No handler matched or all handlers passed // No handler matched or all handlers passed
@@ -196,22 +213,3 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
} }
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
return strings.EqualFold(qname, entry.Pattern)
}
}
}

View File

@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "test.example.com.", queryDomain: "test.example.com.",
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "sub.test.example.com.", queryDomain: "sub.test.example.com.",
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream) chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityUpstream}, {"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: true, nbdns.PriorityMatchDomain: true,
}, },
}, },
{ {
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityUpstream}, {"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityUpstream}, {"remove", "example.com.", nbdns.PriorityMatchDomain},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true, nbdns.PriorityDNSRoute: true,
nbdns.PriorityUpstream: false, nbdns.PriorityMatchDomain: false,
}, },
}, },
{ {
@@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityUpstream}, {"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityDefault}, {"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityUpstream}, {"remove", "example.com.", nbdns.PriorityMatchDomain},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: false, nbdns.PriorityMatchDomain: false,
nbdns.PriorityDefault: true, nbdns.PriorityDefault: true,
}, },
}, },
} }
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state // Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityUpstream) chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool shouldMatch bool
}{ }{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityUpstream, false, false}, {"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true},
}, },
query: "example.com.", query: "example.com.",
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityUpstream, true}, {"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false}, {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityUpstream, true}, {"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true}, {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityUpstream, true}, {"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true}, {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "test.sub.example.com.", nbdns.PriorityUpstream, false}, {"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false}, {"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
}, },
query: "test.sub.example.com.", query: "test.sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "sub.example.com.", nbdns.PriorityUpstream, false}, {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true}, {"add", "example.com.", nbdns.PriorityDNSRoute, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityUpstream, true}, {"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "other.example.com.", nbdns.PriorityUpstream, true}, {"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false}, {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",

View File

@@ -1,14 +1,11 @@
package dns package dns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os/exec"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -44,20 +41,6 @@ const (
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
registrationEnabledKey = "RegistrationEnabled"
maxNumberOfAddressesToRegisterKey = "MaxNumberOfAddressesToRegister"
// NetBIOS/WINS settings
netbtInterfacePath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces`
netbiosOptionsKey = "NetbiosOptions"
// NetBIOS option values: 0 = from DHCP, 1 = enabled, 2 = disabled
netbiosFromDHCP = 0
netbiosEnabled = 1
netbiosDisabled = 2
// RP_FORCE: Reapply all policies even if no policy change was detected // RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1 rpForce = 0x1
) )
@@ -84,85 +67,16 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
log.Infof("detected GPO DNS policy configuration, using policy store") log.Infof("detected GPO DNS policy configuration, using policy store")
} }
configurator := &registryConfigurator{ return &registryConfigurator{
guid: guid, guid: guid,
gpo: useGPO, gpo: useGPO,
} }, nil
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
return configurator, nil
} }
func (r *registryConfigurator) supportCustomPort() bool { func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) configureInterface() error {
var merr *multierror.Error
if err := r.disableDNSRegistrationForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable DNS registration: %w", err))
}
if err := r.disableWINSForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable WINS: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableDNSRegistrationForInterface() error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closer(regKey)
var merr *multierror.Error
if err := regKey.SetDWordValue(disableDynamicUpdateKey, 1); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", disableDynamicUpdateKey, err))
}
if err := regKey.SetDWordValue(registrationEnabledKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", registrationEnabledKey, err))
}
if err := regKey.SetDWordValue(maxNumberOfAddressesToRegisterKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", maxNumberOfAddressesToRegisterKey, err))
}
if merr == nil || len(merr.Errors) == 0 {
log.Infof("disabled DNS registration for interface %s", r.guid)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableWINSForInterface() error {
netbtKeyPath := fmt.Sprintf(`%s\Tcpip_%s`, netbtInterfacePath, r.guid)
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
regKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create NetBT interface key %s: %w", netbtKeyPath, err)
}
}
defer closer(regKey)
// NetbiosOptions: 2 = disabled
if err := regKey.SetDWordValue(netbiosOptionsKey, netbiosDisabled); err != nil {
return fmt.Errorf("set %s: %w", netbiosOptionsKey, err)
}
log.Infof("disabled WINS/NetBIOS for interface %s", r.guid)
return nil
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if config.RouteAll { if config.RouteAll {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil { if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
@@ -205,7 +119,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
} }
go r.flushDNSCache() if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }
@@ -275,25 +191,7 @@ func (r *registryConfigurator) string() string {
return "registry" return "registry"
} }
func (r *registryConfigurator) registerDNS() { func (r *registryConfigurator) flushDNSCache() error {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
// nolint:misspell
cmd := exec.CommandContext(ctx, "ipconfig", "/registerdns")
out, err := cmd.CombinedOutput()
if err != nil {
log.Errorf("failed to register DNS: %v, output: %s", err, out)
return
}
log.Info("registered DNS names")
}
func (r *registryConfigurator) flushDNSCache() {
r.registerDNS()
// dnsFlushResolverCacheFn.Call() may panic if the func is not found // dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
@@ -304,14 +202,13 @@ func (r *registryConfigurator) flushDNSCache() {
ret, _, err := dnsFlushResolverCacheFn.Call() ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 { if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) { if err != nil && !errors.Is(err, syscall.Errno(0)) {
log.Errorf("DnsFlushResolverCache failed: %v", err) return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
return
} }
log.Errorf("DnsFlushResolverCache failed") return fmt.Errorf("DnsFlushResolverCache failed")
return
} }
log.Info("flushed DNS cache") log.Info("flushed DNS cache")
return nil
} }
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
@@ -366,7 +263,9 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
go r.flushDNSCache() if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }

View File

@@ -12,19 +12,16 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
type Resolver struct { type Resolver struct {
mu sync.RWMutex mu sync.RWMutex
records map[dns.Question][]dns.RR records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
} }
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
records: make(map[dns.Question][]dns.RR), records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
} }
} }
@@ -67,12 +64,8 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.Rcode = dns.RcodeSuccess replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...) replyMessage.Answer = append(replyMessage.Answer, records...)
} else { } else {
// Check if we have any records for this domain name with different types // TODO: return success if we have a different record type for the same name, relevant for search domains
if d.hasRecordsForDomain(domain.Domain(question.Name)) { replyMessage.Rcode = dns.RcodeNameError
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
} }
if err := w.WriteMsg(replyMessage); err != nil { if err := w.WriteMsg(replyMessage); err != nil {
@@ -80,15 +73,6 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
return exists
}
// lookupRecords fetches *all* DNS records matching the first question in r. // lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock() d.mu.RLock()
@@ -127,7 +111,6 @@ func (d *Resolver) Update(update []nbdns.SimpleRecord) {
defer d.mu.Unlock() defer d.mu.Unlock()
maps.Clear(d.records) maps.Clear(d.records)
maps.Clear(d.domains)
for _, rec := range update { for _, rec := range update {
if err := d.registerRecord(rec); err != nil { if err := d.registerRecord(rec); err != nil {
@@ -161,7 +144,6 @@ func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
} }
d.records[q] = append(d.records[q], rr) d.records[q] = append(d.records[q], rr)
d.domains[domain.Domain(q.Name)] = struct{}{}
return nil return nil
} }

View File

@@ -470,115 +470,3 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
}) })
} }
} }
// TestLocalResolver_NoErrorWithDifferentRecordType verifies that querying for a record type
// that doesn't exist but where other record types exist for the same domain returns NOERROR
// with 0 records instead of NXDOMAIN
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
resolver := NewResolver()
recordA := nbdns.SimpleRecord{
Name: "example.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
}
recordCNAME := nbdns.SimpleRecord{
Name: "alias.netbird.cloud.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
testCases := []struct {
name string
queryName string
queryType uint16
expectedRcode int
shouldHaveData bool
}{
{
name: "Query A record that exists",
queryName: "example.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query AAAA for domain with only A record",
queryName: "example.netbird.cloud.",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query other record with different case and non-fqdn",
queryName: "EXAMPLE.netbird.cloud",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query TXT for domain with only A record",
queryName: "example.netbird.cloud.",
queryType: dns.TypeTXT,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query A for domain with only CNAME record",
queryName: "alias.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query AAAA for domain with only CNAME record",
queryName: "alias.netbird.cloud.",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query for completely non-existent domain",
queryName: "nonexistent.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeNameError,
shouldHaveData: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response message")
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode,
"Response code should be %d (%s)",
tc.expectedRcode, dns.RcodeToString[tc.expectedRcode])
if tc.shouldHaveData {
assert.Greater(t, len(responseMSG.Answer), 0, "Response should contain answers")
} else {
assert.Equal(t, 0, len(responseMSG.Answer), "Response should contain no answers")
}
})
}
}

View File

@@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
} }
} }
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains)) log.Debugf("extra match domains: %v", s.extraDomains)
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err) log.Errorf("failed to apply DNS host manager update: %v", err)
@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityLocal, priority: PriorityMatchDomain,
}) })
for _, record := range customZone.Records { for _, record := range customZone.Records {
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups) groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS { for _, domainGroup := range groupedNS {
basePriority := PriorityUpstream basePriority := PriorityMatchDomain
if domainGroup.domain == nbdns.RootZone { if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault basePriority = PriorityDefault
} }
@@ -588,14 +588,10 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i priority := basePriority - i
// Check if we're about to overlap with the next priority tier. // Check if we're about to overlap with the next priority tier
// This boundary check ensures that the priority of upstream handlers does not conflict if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityUpstream-PriorityDefault) domainGroup.domain, PriorityMatchDomain-PriorityDefault)
break break
} }

View File

@@ -46,9 +46,10 @@ func (w *mocWGIface) Name() string {
} }
func (w *mocWGIface) Address() wgaddr.Address { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.66.100.1"), IP: ip,
Network: netip.MustParsePrefix("100.66.100.0/24"), Network: network,
} }
} }
@@ -164,12 +165,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
dummyHandler.ID(): handlerWrapper{ dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityLocal, priority: PriorityMatchDomain,
}, },
generateDummyHandler(".", nameServers).ID(): handlerWrapper{ generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone, domain: nbdns.RootZone,
@@ -186,7 +187,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -210,12 +211,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
"local-resolver": handlerWrapper{ "local-resolver": handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityLocal, priority: PriorityMatchDomain,
}, },
}, },
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@@ -305,7 +306,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -321,7 +322,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -463,10 +464,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil { if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err) t.Errorf("set packet filter: %v", err)
@@ -495,7 +503,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
} }
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@@ -978,7 +986,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
} }
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream) chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
testCases := []struct { testCases := []struct {
name string name string
@@ -1059,14 +1067,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityMatchDomain - 1,
}, },
} }
@@ -1093,21 +1101,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityMatchDomain - 1,
}, },
"upstream-other": { "upstream-other": {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
} }
@@ -1128,7 +1136,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityMatchDomain - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1146,7 +1154,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1164,7 +1172,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityUpstream + 1, priority: PriorityMatchDomain + 1,
}, },
// Keep existing groups with their original priorities // Keep existing groups with their original priorities
{ {
@@ -1172,14 +1180,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityMatchDomain - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1199,14 +1207,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityMatchDomain - 1,
}, },
// Add group3 with lowest priority // Add group3 with lowest priority
{ {
@@ -1214,7 +1222,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityUpstream - 2, priority: PriorityMatchDomain - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1335,14 +1343,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1360,28 +1368,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityMatchDomain - 1,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
{ {
domain: "new.com", domain: "new.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-new", Id: "upstream-new",
}, },
priority: PriorityUpstream, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1791,14 +1799,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain // Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream) server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
// Verify refcount is 2 // Verify refcount is 2
zoneKey := toZone("shared.example.com") zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler // Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream) server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
// Verify refcount is 1 // Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@@ -1925,7 +1933,7 @@ func TestDomainCaseHandling(t *testing.T) {
} }
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream) server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@@ -1945,111 +1953,3 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
} }
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@@ -24,15 +24,11 @@ type ServiceViaMemory struct {
} }
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{ s := &ServiceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: lastIP.String(), runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
runtimePort: defaultPort, runtimePort: defaultPort,
} }
return s return s
@@ -95,7 +91,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
} }
firstLayerDecoder := layers.LayerTypeIPv4 firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().IP.Is6() { if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6 firstLayerDecoder = layers.LayerTypeIPv6
} }

View File

@@ -0,0 +1,33 @@
package dns
import (
"net"
"testing"
nbnet "github.com/netbirdio/netbird/util/net"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -30,12 +30,9 @@ const (
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS" systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusResolvConfModeForeign = "foreign" systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
dnsSecDisabled = "no"
) )
type systemdDbusConfigurator struct { type systemdDbusConfigurator struct {
@@ -98,13 +95,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
Family: unix.AF_INET, Family: unix.AF_INET,
Address: ipAs4[:], Address: ipAs4[:],
} }
if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err) if err != nil {
} return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSSEC to 'no': %v", err)
} }
var ( var (

View File

@@ -2,7 +2,6 @@ package dns
import ( import (
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
@@ -104,21 +103,19 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error var err error
defer func() { defer func() {
u.checkUpstreamFails(err) u.checkUpstreamFails(err)
}() }()
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
logger.Tracef("%s has been stopped", u) log.Tracef("%s has been stopped", u)
return return
default: default:
} }
@@ -135,35 +132,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue continue
} }
if rm == nil || !rm.Response { if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
u.successCount.Add(1) u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil { if err = w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
} }
// count the fails only if they happen sequentially // count the fails only if they happen sequentially
u.failsCount.Store(0) u.failsCount.Store(0)
return return
} }
u.failsCount.Add(1) u.failsCount.Add(1)
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure) m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil { if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
} }
} }
@@ -388,13 +385,3 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil return rm, t, nil
} }
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@@ -3,7 +3,6 @@ package dns
import ( import (
"context" "context"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@@ -24,8 +23,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ netip.Addr, _ net.IP,
_ netip.Prefix, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string, domain string,
@@ -84,10 +83,3 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
} }
return false return false
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -4,7 +4,7 @@ package dns
import ( import (
"context" "context"
"net/netip" "net"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -19,8 +19,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ netip.Addr, _ net.IP,
_ netip.Prefix, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@@ -36,10 +36,3 @@ func newUpstreamResolver(
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -6,7 +6,6 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@@ -19,16 +18,16 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP netip.Addr lIP net.IP
lNet netip.Prefix lNet *net.IPNet
interfaceName string interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
interfaceName string, interfaceName string,
ip netip.Addr, ip net.IP,
net netip.Prefix, net *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@@ -59,11 +58,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} }
client.DialTimeout = timeout client.DialTimeout = timeout
upstreamIP, err := netip.ParseAddr(upstreamHost) upstreamIP := net.ParseIP(upstreamHost)
if err != nil { if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil { if err != nil {
@@ -77,7 +73,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName) index, err := getInterfaceIndex(interfaceName)
if err != nil { if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err) log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
@@ -86,7 +82,7 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: ip.AsSlice(), IP: ip,
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,

View File

@@ -2,7 +2,7 @@ package dns
import ( import (
"context" "context"
"net/netip" "net"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {

View File

@@ -5,6 +5,7 @@ package dns
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -17,4 +18,5 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
} }

View File

@@ -1,6 +1,7 @@
package dns package dns
import ( import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -12,5 +13,6 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error) GetInterfaceGUIDString() (string, error)
} }

View File

@@ -18,20 +18,14 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct { type DNSForwarder struct {
listenAddress string listenAddress string
ttl uint32 ttl uint32
@@ -44,18 +38,16 @@ type DNSForwarder struct {
mutex sync.RWMutex mutex sync.RWMutex
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewaller firewall firewall.Manager
resolver resolver
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
ttl: ttl, ttl: ttl,
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
} }
} }
@@ -65,17 +57,14 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server // UDP server
mux := dns.NewServeMux() mux := dns.NewServeMux()
f.mux = mux f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{ f.dnsServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
} }
// TCP server // TCP server
tcpMux := dns.NewServeMux() tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{ f.tcpServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "tcp", Net: "tcp",
@@ -98,13 +87,30 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown) // return the first error we get (e.g. bind failure or shutdown)
return <-errCh return <-errCh
} }
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated DNS forwarder with %d domains", len(entries)) log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -151,31 +157,22 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil return nil
} }
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain) ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(w, query, resp, domain, err)
return nil return nil
} }
f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.updateInternalState(domain, ips)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
return resp return resp
} }
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query) resp := f.handleDNSQuery(w, query)
if resp == nil { if resp == nil {
return return
@@ -209,8 +206,9 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
} }
} }
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
var prefixes []netip.Prefix var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" { if mostSpecificResId != "" {
for _, ip := range ips { for _, ip := range ips {
var prefix netip.Prefix var prefix netip.Prefix
@@ -341,3 +339,16 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches return selectedResId, matches
} }
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,21 +1,11 @@
package dnsfwd package dnsfwd
import ( import (
"context"
"fmt"
"net/netip"
"strings"
"testing" "testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -23,7 +13,7 @@ import (
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
storedMappings map[string]route.ResID storedMappings map[string]route.ResID // key: domain pattern, value: resId
queryDomain string queryDomain string
expectedResId route.ResID expectedResId route.ResID
}{ }{
@@ -54,7 +44,7 @@ func Test_getMatchingEntries(t *testing.T) {
{ {
name: "Wildcard pattern does not match different domain", name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"}, storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.example.org", queryDomain: "foo.notexample.com",
expectedResId: "", expectedResId: "",
}, },
{ {
@@ -111,619 +101,3 @@ func Test_getMatchingEntries(t *testing.T) {
}) })
} }
} }
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

Some files were not shown because too many files have changed in this diff Show More