diff --git a/client/internal/connect.go b/client/internal/connect.go index 83909dfdd..d34d0aab0 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "runtime" "runtime/debug" "strings" @@ -330,6 +331,15 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe engineConf.PreSharedKey = &preSharedKey } + port, err := freePort(config.WgPort) + if err != nil { + return nil, err + } + if port != config.WgPort { + log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort) + } + engineConf.WgPort = port + return engineConf, nil } @@ -379,3 +389,20 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal notifier, _ := sri.(signal.ConnStateNotifier) return notifier } + +func freePort(start int) (int, error) { + addr := net.UDPAddr{} + if start == 0 { + start = iface.DefaultWgPort + } + for x := start; x <= 65535; x++ { + addr.Port = x + conn, err := net.ListenUDP("udp", &addr) + if err != nil { + continue + } + conn.Close() + return x, nil + } + return 0, errors.New("no free ports") +} diff --git a/client/internal/connect_test.go b/client/internal/connect_test.go new file mode 100644 index 000000000..6f4a6bbb7 --- /dev/null +++ b/client/internal/connect_test.go @@ -0,0 +1,57 @@ +package internal + +import ( + "net" + "testing" +) + +func Test_freePort(t *testing.T) { + tests := []struct { + name string + port int + want int + wantErr bool + }{ + { + name: "available", + port: 51820, + want: 51820, + wantErr: false, + }, + { + name: "notavailable", + port: 51830, + want: 51831, + wantErr: false, + }, + { + name: "noports", + port: 65535, + want: 0, + wantErr: true, + }, + } + for _, tt := range tests { + + c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830}) + if err != nil { + t.Errorf("freePort error = %v", err) + } + c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535}) + if err != nil { + t.Errorf("freePort error = %v", err) + } + t.Run(tt.name, func(t *testing.T) { + got, err := freePort(tt.port) + if (err != nil) != tt.wantErr { + t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("freePort() = %v, want %v", got, tt.want) + } + }) + c1.Close() + c2.Close() + } +}