mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 07:06:41 +00:00
chore: make wireguard work over webrtc
This commit is contained in:
@@ -14,6 +14,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const initDataChannelName = "wiretrustee-init"
|
||||
|
||||
func (*WebRTCBind) makeReceive(dcConn net.Conn) conn.ReceiveFunc {
|
||||
return func(buff []byte) (int, conn.Endpoint, error) {
|
||||
log.Printf("receiving from endpoint %s", dcConn.RemoteAddr().String())
|
||||
@@ -21,8 +23,8 @@ func (*WebRTCBind) makeReceive(dcConn net.Conn) conn.ReceiveFunc {
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
addr := dcConn.RemoteAddr().(*DataChannelAddr)
|
||||
return n, (*WebRTCEndpoint)(addr), err
|
||||
//addr := dcConn.RemoteAddr().(DataChannelAddr)
|
||||
return n, &WebRTCEndpoint{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,6 +61,9 @@ func NewWebRTCBind(id string, signal signal.Client, pubKey string, remotePubKey
|
||||
// blocks until channel was successfully opened
|
||||
func (bind *WebRTCBind) acceptDC() (stream net.Conn, err error) {
|
||||
for dc := range bind.incoming {
|
||||
if dc.Label() == initDataChannelName {
|
||||
continue
|
||||
}
|
||||
stream, err := WrapDataChannel(dc)
|
||||
if err != nil {
|
||||
dc.Close()
|
||||
@@ -125,13 +130,14 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
|
||||
})
|
||||
|
||||
bind.pc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
log.Printf("received channel %s %v", dc.Label(), dc)
|
||||
bind.incoming <- dc
|
||||
})
|
||||
|
||||
controlling := bind.key < bind.remoteKey
|
||||
// decision who is creating an offer
|
||||
if controlling {
|
||||
_, err = bind.pc.CreateDataChannel(bind.id, nil)
|
||||
_, err = bind.pc.CreateDataChannel(initDataChannelName, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -152,7 +158,6 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
|
||||
case <-bind.closeCond.C:
|
||||
return nil, 0, fmt.Errorf("closed while waiting for WebRTC candidates")
|
||||
}
|
||||
|
||||
log.Printf("candidates gathered")
|
||||
|
||||
err = bind.signal.Send(&proto.Message{
|
||||
@@ -236,7 +241,7 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(30 * time.Second):
|
||||
case <-time.After(10 * time.Second):
|
||||
return nil, 0, fmt.Errorf("failed to connect in time: %w", err)
|
||||
case <-connected.C:
|
||||
}
|
||||
@@ -257,7 +262,7 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
|
||||
}
|
||||
bind.conn = dcConn
|
||||
fns = append(fns, bind.makeReceive(bind.conn))
|
||||
return fns, 38676, nil
|
||||
return fns, 0, nil
|
||||
|
||||
}
|
||||
|
||||
@@ -303,10 +308,11 @@ func (*WebRTCBind) SetMark(mark uint32) error {
|
||||
}
|
||||
|
||||
func (bind *WebRTCBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
_, err := bind.conn.Write(b)
|
||||
n, err := bind.conn.Write(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("wrote %d bytes", n)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@@ -56,6 +57,7 @@ func WrapDataChannel(rtcDataChannel *webrtc.DataChannel) (*DataChannelConn, erro
|
||||
conn.openCond.Signal()
|
||||
})
|
||||
conn.dc.OnMessage(func(msg webrtc.DataChannelMessage) {
|
||||
log.Printf("received message from data channel %d", len(msg.Data))
|
||||
if rw != nil {
|
||||
_, err := rw.Write(msg.Data)
|
||||
if err != nil {
|
||||
@@ -84,6 +86,7 @@ func (dc *DataChannelConn) Read(b []byte) (n int, err error) {
|
||||
|
||||
func (dc *DataChannelConn) Write(b []byte) (n int, err error) {
|
||||
err = dc.dc.Send(b)
|
||||
log.Printf("writing to channel %s %v", dc.dc.Label(), dc.dc)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -165,6 +168,7 @@ func (cr ContextReadCloser) SetReadDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (cr ContextReadCloser) Read(p []byte) (n int, err error) {
|
||||
log.Printf("reading bytes ro buf of len %d", len(p))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
n, err = cr.ReadCloser.Read(p)
|
||||
|
||||
Reference in New Issue
Block a user