chore: make wireguard work over webrtc

This commit is contained in:
braginini
2021-11-16 14:36:24 +01:00
parent f0a0888046
commit b35dcd21df
8 changed files with 113 additions and 80 deletions

View File

@@ -14,6 +14,8 @@ import (
"time"
)
const initDataChannelName = "wiretrustee-init"
func (*WebRTCBind) makeReceive(dcConn net.Conn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
log.Printf("receiving from endpoint %s", dcConn.RemoteAddr().String())
@@ -21,8 +23,8 @@ func (*WebRTCBind) makeReceive(dcConn net.Conn) conn.ReceiveFunc {
if err != nil {
return 0, nil, err
}
addr := dcConn.RemoteAddr().(*DataChannelAddr)
return n, (*WebRTCEndpoint)(addr), err
//addr := dcConn.RemoteAddr().(DataChannelAddr)
return n, &WebRTCEndpoint{}, err
}
}
@@ -59,6 +61,9 @@ func NewWebRTCBind(id string, signal signal.Client, pubKey string, remotePubKey
// blocks until channel was successfully opened
func (bind *WebRTCBind) acceptDC() (stream net.Conn, err error) {
for dc := range bind.incoming {
if dc.Label() == initDataChannelName {
continue
}
stream, err := WrapDataChannel(dc)
if err != nil {
dc.Close()
@@ -125,13 +130,14 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
})
bind.pc.OnDataChannel(func(dc *webrtc.DataChannel) {
log.Printf("received channel %s %v", dc.Label(), dc)
bind.incoming <- dc
})
controlling := bind.key < bind.remoteKey
// decision who is creating an offer
if controlling {
_, err = bind.pc.CreateDataChannel(bind.id, nil)
_, err = bind.pc.CreateDataChannel(initDataChannelName, nil)
if err != nil {
return nil, 0, err
}
@@ -152,7 +158,6 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
case <-bind.closeCond.C:
return nil, 0, fmt.Errorf("closed while waiting for WebRTC candidates")
}
log.Printf("candidates gathered")
err = bind.signal.Send(&proto.Message{
@@ -236,7 +241,7 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
}
select {
case <-time.After(30 * time.Second):
case <-time.After(10 * time.Second):
return nil, 0, fmt.Errorf("failed to connect in time: %w", err)
case <-connected.C:
}
@@ -257,7 +262,7 @@ func (bind *WebRTCBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort ui
}
bind.conn = dcConn
fns = append(fns, bind.makeReceive(bind.conn))
return fns, 38676, nil
return fns, 0, nil
}
@@ -303,10 +308,11 @@ func (*WebRTCBind) SetMark(mark uint32) error {
}
func (bind *WebRTCBind) Send(b []byte, ep conn.Endpoint) error {
_, err := bind.conn.Write(b)
n, err := bind.conn.Write(b)
if err != nil {
return err
}
log.Printf("wrote %d bytes", n)
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"github.com/pion/webrtc/v3"
"io"
"log"
"net"
"time"
)
@@ -56,6 +57,7 @@ func WrapDataChannel(rtcDataChannel *webrtc.DataChannel) (*DataChannelConn, erro
conn.openCond.Signal()
})
conn.dc.OnMessage(func(msg webrtc.DataChannelMessage) {
log.Printf("received message from data channel %d", len(msg.Data))
if rw != nil {
_, err := rw.Write(msg.Data)
if err != nil {
@@ -84,6 +86,7 @@ func (dc *DataChannelConn) Read(b []byte) (n int, err error) {
func (dc *DataChannelConn) Write(b []byte) (n int, err error) {
err = dc.dc.Send(b)
log.Printf("writing to channel %s %v", dc.dc.Label(), dc.dc)
if err != nil {
return 0, err
}
@@ -165,6 +168,7 @@ func (cr ContextReadCloser) SetReadDeadline(t time.Time) error {
}
func (cr ContextReadCloser) Read(p []byte) (n int, err error) {
log.Printf("reading bytes ro buf of len %d", len(p))
done := make(chan struct{})
go func() {
n, err = cr.ReadCloser.Read(p)