diff --git a/relay/client/client_test.go b/relay/client/client_test.go index f5d122276..278d46d08 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -23,10 +23,10 @@ func TestMain(m *testing.M) { func TestClient(t *testing.T) { ctx := context.Background() - addr := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -39,21 +39,21 @@ func TestClient(t *testing.T) { } }() - clientAlice := NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, srvCfg.Address, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientAlice.Close() - clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder") + clientPlaceHolder := NewClient(ctx, srvCfg.Address, "clientPlaceHolder") err = clientPlaceHolder.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientPlaceHolder.Close() - clientBob := NewClient(ctx, addr, "bob") + clientBob := NewClient(ctx, srvCfg.Address, "bob") err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -91,16 +91,16 @@ func TestClient(t *testing.T) { func TestRegistration(t *testing.T) { ctx := context.Background() - addr := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } }() - clientAlice := NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, srvCfg.Address, "alice") err := clientAlice.Connect() if err != nil { _ = srv.Close() @@ -156,10 +156,10 @@ func TestEcho(t *testing.T) { ctx := context.Background() idAlice := "alice" idBob := "bob" - addr := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -172,7 +172,7 @@ func TestEcho(t *testing.T) { } }() - clientAlice := NewClient(ctx, addr, idAlice) + clientAlice := NewClient(ctx, srvCfg.Address, idAlice) err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -184,7 +184,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := NewClient(ctx, addr, idBob) + clientBob := NewClient(ctx, srvCfg.Address, idBob) err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -236,10 +236,10 @@ func TestEcho(t *testing.T) { func TestBindToUnavailabePeer(t *testing.T) { ctx := context.Background() - addr := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -253,7 +253,7 @@ func TestBindToUnavailabePeer(t *testing.T) { } }() - clientAlice := NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, srvCfg.Address, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -273,10 +273,10 @@ func TestBindToUnavailabePeer(t *testing.T) { func TestBindReconnect(t *testing.T) { ctx := context.Background() - addr := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Errorf("failed to bind server: %s", err) } @@ -290,7 +290,7 @@ func TestBindReconnect(t *testing.T) { } }() - clientAlice := NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, srvCfg.Address, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -301,7 +301,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to bind channel: %s", err) } - clientBob := NewClient(ctx, addr, "bob") + clientBob := NewClient(ctx, srvCfg.Address, "bob") err = clientBob.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -318,7 +318,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = NewClient(ctx, addr, "alice") + clientAlice = NewClient(ctx, srvCfg.Address, "alice") err = clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -355,10 +355,10 @@ func TestBindReconnect(t *testing.T) { func TestCloseConn(t *testing.T) { ctx := context.Background() - addr := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Errorf("failed to bind server: %s", err) } @@ -372,7 +372,7 @@ func TestCloseConn(t *testing.T) { } }() - clientAlice := NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, srvCfg.Address, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -403,10 +403,10 @@ func TestCloseConn(t *testing.T) { func TestCloseRelayConn(t *testing.T) { ctx := context.Background() - addr := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Errorf("failed to bind server: %s", err) } @@ -419,7 +419,7 @@ func TestCloseRelayConn(t *testing.T) { } }() - clientAlice := NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, srvCfg.Address, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -446,10 +446,10 @@ func TestCloseRelayConn(t *testing.T) { func TestCloseByServer(t *testing.T) { ctx := context.Background() - addr1 := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv1 := server.NewServer() go func() { - err := srv1.Listen(addr1) + err := srv1.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -457,7 +457,7 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, addr1, idAlice) + relayClient := NewClient(ctx, srvCfg.Address, idAlice) err := relayClient.Connect() if err != nil { log.Fatalf("failed to connect to server: %s", err) @@ -489,10 +489,10 @@ func TestCloseByServer(t *testing.T) { func TestCloseByClient(t *testing.T) { ctx := context.Background() - addr1 := "localhost:1234" + srvCfg := server.Config{Address: "localhost:1234"} srv := server.NewServer() go func() { - err := srv.Listen(addr1) + err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -500,7 +500,7 @@ func TestCloseByClient(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, addr1, idAlice) + relayClient := NewClient(ctx, srvCfg.Address, idAlice) err := relayClient.Connect() if err != nil { log.Fatalf("failed to connect to server: %s", err) diff --git a/relay/client/dialer/ws/conn.go b/relay/client/dialer/ws/conn.go index b0b8f5aba..44d86c1bb 100644 --- a/relay/client/dialer/ws/conn.go +++ b/relay/client/dialer/ws/conn.go @@ -12,14 +12,16 @@ import ( type Conn struct { ctx context.Context *websocket.Conn - srvAddr *net.TCPAddr + srvAddr net.Addr + localAddr net.Addr } -func NewConn(wsConn *websocket.Conn, srvAddr *net.TCPAddr) net.Conn { +func NewConn(wsConn *websocket.Conn, srvAddr, localAddr net.Addr) net.Conn { return &Conn{ - ctx: context.Background(), - Conn: wsConn, - srvAddr: srvAddr, + ctx: context.Background(), + Conn: wsConn, + srvAddr: srvAddr, + localAddr: localAddr, } } @@ -46,8 +48,7 @@ func (c *Conn) RemoteAddr() net.Addr { } func (c *Conn) LocalAddr() net.Addr { - // todo: implement me - return nil + return c.localAddr } func (c *Conn) SetReadDeadline(t time.Time) error { diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index 90175ebf9..070a02362 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strings" log "github.com/sirupsen/logrus" "nhooyr.io/websocket" @@ -13,32 +14,48 @@ import ( ) func Dial(address string) (net.Conn, error) { - - hostName, _, err := net.SplitHostPort(address) - - addr, err := net.ResolveTCPAddr("tcp", address) + wsURL, err := prepareURL(address) if err != nil { - log.Errorf("failed to resolve address of Relay server: %s", address) return nil, err } - url := fmt.Sprintf("ws://%s:%d", addr.IP.String(), addr.Port) opts := &websocket.DialOptions{ - Host: hostName, HTTPClient: httpClientNbDialer(), } - wsConn, _, err := websocket.Dial(context.Background(), url, opts) + wsConn, _, err := websocket.Dial(context.Background(), wsURL, opts) if err != nil { - log.Errorf("failed to dial to Relay server '%s': %s", url, err) + log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err) return nil, err } - conn := NewConn(wsConn, addr) + /* + response.Body.(net.Conn).LocalAddr() + unc, ok := response.Body.(net.Conn) + if !ok { + log.Errorf("failed to get local address: %s", err) + return nil, fmt.Errorf("failed to get local address") + } + */ + // todo figure out the proper address + dummy := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 8080, + } + + conn := NewConn(wsConn, dummy, dummy) return conn, nil } +func prepareURL(address string) (string, error) { + if !strings.HasPrefix(address, "rel") { + return "", fmt.Errorf("unsupported scheme: %s", address) + } + + return strings.Replace(address, "rel", "ws", 1), nil +} + func httpClientNbDialer() *http.Client { customDialer := nbnet.NewDialer() diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 69539cc5e..71e4a416f 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -13,10 +13,12 @@ import ( func TestForeignConn(t *testing.T) { ctx := context.Background() - addr1 := "localhost:1234" + srvCfg1 := server.Config{ + Address: "localhost:1234", + } srv1 := server.NewServer() go func() { - err := srv1.Listen(addr1) + err := srv1.Listen(srvCfg1) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -29,10 +31,12 @@ func TestForeignConn(t *testing.T) { } }() - addr2 := "localhost:2234" + srvCfg2 := server.Config{ + Address: "localhost:2234", + } srv2 := server.NewServer() go func() { - err := srv2.Listen(addr2) + err := srv2.Listen(srvCfg2) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -49,12 +53,12 @@ func TestForeignConn(t *testing.T) { log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, addr1, idAlice) + clientAlice := NewManager(mCtx, srvCfg1.Address, idAlice) clientAlice.Serve() idBob := "bob" log.Debugf("connect by bob") - clientBob := NewManager(mCtx, addr2, idBob) + clientBob := NewManager(mCtx, srvCfg2.Address, idBob) clientBob.Serve() bobsSrvAddr, err := clientBob.RelayAddress() @@ -100,10 +104,12 @@ func TestForeignConn(t *testing.T) { func TestForeginConnClose(t *testing.T) { ctx := context.Background() - addr1 := "localhost:1234" + srvCfg1 := server.Config{ + Address: "localhost:1234", + } srv1 := server.NewServer() go func() { - err := srv1.Listen(addr1) + err := srv1.Listen(srvCfg1) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -116,10 +122,12 @@ func TestForeginConnClose(t *testing.T) { } }() - addr2 := "localhost:2234" + srvCfg2 := server.Config{ + Address: "localhost:2234", + } srv2 := server.NewServer() go func() { - err := srv2.Listen(addr2) + err := srv2.Listen(srvCfg2) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -136,10 +144,10 @@ func TestForeginConnClose(t *testing.T) { log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgr := NewManager(mCtx, addr1, idAlice) + mgr := NewManager(mCtx, srvCfg1.Address, idAlice) mgr.Serve() - conn, err := mgr.OpenConn(addr2, "anotherpeer", nil) + conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -153,11 +161,13 @@ func TestForeginConnClose(t *testing.T) { func TestForeginAutoClose(t *testing.T) { ctx := context.Background() relayCleanupInterval = 1 * time.Second - addr1 := "localhost:1234" + srvCfg1 := server.Config{ + Address: "localhost:1234", + } srv1 := server.NewServer() go func() { t.Log("binding server 1.") - err := srv1.Listen(addr1) + err := srv1.Listen(srvCfg1) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -172,11 +182,13 @@ func TestForeginAutoClose(t *testing.T) { t.Logf("server 1. closed") }() - addr2 := "localhost:2234" + srvCfg2 := server.Config{ + Address: "localhost:2234", + } srv2 := server.NewServer() go func() { t.Log("binding server 2.") - err := srv2.Listen(addr2) + err := srv2.Listen(srvCfg2) if err != nil { t.Fatalf("failed to bind server: %s", err) } @@ -194,11 +206,11 @@ func TestForeginAutoClose(t *testing.T) { t.Log("connect to server 1.") mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgr := NewManager(mCtx, addr1, idAlice) + mgr := NewManager(mCtx, srvCfg1.Address, idAlice) mgr.Serve() t.Log("open connection to another peer") - conn, err := mgr.OpenConn(addr2, "anotherpeer", nil) + conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -222,10 +234,12 @@ func TestAutoReconnect(t *testing.T) { ctx := context.Background() reconnectingTimeout = 2 * time.Second - addr := "localhost:1234" + srvCfg := server.Config{ + Address: "localhost:1234", + } srv := server.NewServer() go func() { - err := srv.Listen(addr) + err := srv.Listen(srvCfg) if err != nil { t.Errorf("failed to bind server: %s", err) } @@ -240,7 +254,7 @@ func TestAutoReconnect(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, addr, "alice") + clientAlice := NewManager(mCtx, srvCfg.Address, "alice") clientAlice.Serve() ra, err := clientAlice.RelayAddress() if err != nil { diff --git a/relay/cmd/main.go b/relay/cmd/main.go index eb7d87a1f..9d1802076 100644 --- a/relay/cmd/main.go +++ b/relay/cmd/main.go @@ -1,6 +1,8 @@ package main import ( + "crypto/tls" + "fmt" "os" "os/signal" "syscall" @@ -8,12 +10,15 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/util" ) var ( - listenAddress string + listenAddress string + letsencryptDataDir string + letsencryptDomain string rootCmd = &cobra.Command{ Use: "relay", @@ -26,7 +31,8 @@ var ( func init() { _ = util.InitLog("trace", "console") rootCmd.PersistentFlags().StringVarP(&listenAddress, "listen-address", "l", ":1235", "listen address") - + rootCmd.PersistentFlags().StringVarP(&letsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.") + rootCmd.PersistentFlags().StringVarP(&letsencryptDomain, "letsencrypt-domain", "a", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") } func waitForExitSignal() { @@ -36,8 +42,20 @@ func waitForExitSignal() { } func execute(cmd *cobra.Command, args []string) { + srvCfg := server.Config{ + Address: listenAddress, + } + if hasLetsEncrypt() { + tlscfg, err := setupTLS() + if err != nil { + log.Errorf("%s", err) + os.Exit(1) + } + srvCfg.TLSConfig = tlscfg + } + srv := server.NewServer() - err := srv.Listen(listenAddress) + err := srv.Listen(srvCfg) if err != nil { log.Errorf("failed to bind server: %s", err) os.Exit(1) @@ -52,6 +70,18 @@ func execute(cmd *cobra.Command, args []string) { } } +func hasLetsEncrypt() bool { + return letsencryptDataDir != "" && letsencryptDomain != "" +} + +func setupTLS() (*tls.Config, error) { + certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomain) + if err != nil { + return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err) + } + return certManager.TLSConfig(), nil +} + func main() { err := rootCmd.Execute() if err != nil { diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 26ba2276e..6b2f669f4 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -2,6 +2,7 @@ package ws import ( "context" + "crypto/tls" "errors" "fmt" "net" @@ -10,35 +11,37 @@ import ( log "github.com/sirupsen/logrus" "nhooyr.io/websocket" - - "github.com/netbirdio/netbird/relay/server/listener" ) type Listener struct { - address string + // Address is the address to listen on. + Address string + // TLSConfig is the TLS configuration for the server. + TLSConfig *tls.Config server *http.Server acceptFn func(conn net.Conn) } -func NewListener(address string) listener.Listener { - return &Listener{ - address: address, - } -} - func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { l.acceptFn = acceptFn mux := http.NewServeMux() mux.HandleFunc("/", l.onAccept) l.server = &http.Server{ - Addr: l.address, - Handler: mux, + Addr: l.Address, + Handler: mux, + TLSConfig: l.TLSConfig, } - log.Infof("WS server is listening on address: %s", l.address) - err := l.server.ListenAndServe() + log.Infof("WS server is listening on address: %s", l.Address) + var err error + if l.TLSConfig != nil { + err = l.server.ListenAndServeTLS("", "") + + } else { + err = l.server.ListenAndServe() + } if errors.Is(err, http.ErrServerClosed) { return nil } diff --git a/relay/server/server.go b/relay/server/server.go index cf48e19e3..687486138 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/tls" "errors" "sync" "time" @@ -13,6 +14,11 @@ import ( "github.com/netbirdio/netbird/relay/server/listener/ws" ) +type Config struct { + Address string + TLSConfig *tls.Config +} + type Server struct { relay *Relay uDPListener listener.Listener @@ -25,11 +31,15 @@ func NewServer() *Server { } } -func (r *Server) Listen(address string) error { +func (r *Server) Listen(cfg Config) error { wg := sync.WaitGroup{} wg.Add(2) - r.wSListener = ws.NewListener(address) + r.wSListener = &ws.Listener{ + Address: cfg.Address, + TLSConfig: cfg.TLSConfig, + } + var wslErr error go func() { defer wg.Done() @@ -39,7 +49,7 @@ func (r *Server) Listen(address string) error { } }() - r.uDPListener = udp.NewListener(address) + r.uDPListener = udp.NewListener(cfg.Address) var udpLErr error go func() { defer wg.Done()