From 2b369cd28f2a2fd8c5d00a06d525bbe9af5e8677 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Mon, 3 Jun 2024 20:17:43 +0200 Subject: [PATCH] Add quic transporter --- relay/client/dialer/quic/conn.go | 52 ++++++++++++ relay/client/dialer/quic/quic.go | 32 +++++++ relay/server/listener/quic/conn.go | 36 ++++++++ relay/server/listener/quic/listener.go | 111 +++++++++++++++++++++++++ 4 files changed, 231 insertions(+) create mode 100644 relay/client/dialer/quic/conn.go create mode 100644 relay/client/dialer/quic/quic.go create mode 100644 relay/server/listener/quic/conn.go create mode 100644 relay/server/listener/quic/listener.go diff --git a/relay/client/dialer/quic/conn.go b/relay/client/dialer/quic/conn.go new file mode 100644 index 000000000..2c8b29a3c --- /dev/null +++ b/relay/client/dialer/quic/conn.go @@ -0,0 +1,52 @@ +package quic + +import ( + "net" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" +) + +type Conn struct { + quic.Stream + qConn quic.Connection +} + +func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn { + return &Conn{ + Stream: stream, + qConn: qConn, + } +} + +func (q *Conn) Write(b []byte) (n int, err error) { + log.Debugf("writing: %d, %x\n", len(b), b) + n, err = q.Stream.Write(b) + if n != len(b) { + log.Errorf("failed to write out the full message") + } + return +} + +func (q *Conn) Close() error { + err := q.Stream.Close() + if err != nil { + log.Errorf("failed to close stream: %s", err) + return err + } + err = q.qConn.CloseWithError(0, "") + if err != nil { + log.Errorf("failed to close connection: %s", err) + return err + + } + return err +} + +func (c *Conn) LocalAddr() net.Addr { + return c.qConn.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.qConn.RemoteAddr() +} diff --git a/relay/client/dialer/quic/quic.go b/relay/client/dialer/quic/quic.go new file mode 100644 index 000000000..4863ed7bd --- /dev/null +++ b/relay/client/dialer/quic/quic.go @@ -0,0 +1,32 @@ +package quic + +import ( + "context" + "crypto/tls" + "net" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" +) + +func Dial(address string) (net.Conn, error) { + tlsConf := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"quic-echo-example"}, + } + qConn, err := quic.DialAddr(context.Background(), address, tlsConf, &quic.Config{ + EnableDatagrams: true, + }) + if err != nil { + log.Errorf("dial quic address %s failed: %s", address, err) + return nil, err + } + + stream, err := qConn.OpenStreamSync(context.Background()) + if err != nil { + return nil, err + } + + conn := NewConn(stream, qConn) + return conn, nil +} diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go new file mode 100644 index 000000000..12d22bb5e --- /dev/null +++ b/relay/server/listener/quic/conn.go @@ -0,0 +1,36 @@ +package quic + +import ( + "net" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" +) + +type QuicConn struct { + quic.Stream + qConn quic.Connection +} + +func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn { + return &QuicConn{ + Stream: stream, + qConn: qConn, + } +} + +func (q QuicConn) Write(b []byte) (n int, err error) { + n, err = q.Stream.Write(b) + if n != len(b) { + log.Errorf("failed to write out the full message") + } + return +} + +func (q QuicConn) LocalAddr() net.Addr { + return q.qConn.LocalAddr() +} + +func (q QuicConn) RemoteAddr() net.Addr { + return q.qConn.RemoteAddr() +} diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go new file mode 100644 index 000000000..8244bb2e2 --- /dev/null +++ b/relay/server/listener/quic/listener.go @@ -0,0 +1,111 @@ +package quic + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" + "net" + "sync" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/server/listener" +) + +type Listener struct { + address string + onAcceptFn func(conn net.Conn) + + listener *quic.Listener + quit chan struct{} + wg sync.WaitGroup +} + +func NewListener(address string) listener.Listener { + return &Listener{ + address: address, + } +} + +func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error { + ql, err := quic.ListenAddr(l.address, generateTLSConfig(), &quic.Config{ + EnableDatagrams: true, + }) + if err != nil { + return err + } + l.listener = ql + l.quit = make(chan struct{}) + + log.Infof("quic server is listening on address: %s", l.address) + l.wg.Add(1) + go l.acceptLoop(onAcceptFn) + + <-l.quit + return nil +} + +func (l *Listener) Close() error { + close(l.quit) + err := l.listener.Close() + l.wg.Wait() + return err +} + +func (l *Listener) acceptLoop(acceptFn func(conn net.Conn)) { + defer l.wg.Done() + + for { + qConn, err := l.listener.Accept(context.Background()) + if err != nil { + select { + case <-l.quit: + return + default: + log.Errorf("failed to accept connection: %s", err) + continue + } + } + + log.Infof("new connection from: %s", qConn.RemoteAddr()) + + stream, err := qConn.AcceptStream(context.Background()) + if err != nil { + log.Errorf("failed to open stream: %s", err) + continue + } + + conn := NewConn(stream, qConn) + + go acceptFn(conn) + } +} + +// Setup a bare-bones TLS config for the server +func generateTLSConfig() *tls.Config { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{"quic-echo-example"}, + } +}