package bind import ( "context" "errors" "fmt" "net" "net/netip" "sync" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) type Bind interface { SetEndpoint(addr netip.Addr, conn net.Conn) RemoveEndpoint(addr netip.Addr) ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte) } type ProxyBind struct { bind Bind // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address wgRelayedEndpoint *bind.Endpoint wgCurrentUsed *bind.Endpoint remoteConn net.Conn ctx context.Context cancel context.CancelFunc closeMu sync.Mutex closed bool paused bool pausedCond *sync.Cond isStarted bool closeListener *listener.CloseListener mtu uint16 } func NewProxyBind(bind Bind, mtu uint16) *ProxyBind { p := &ProxyBind{ bind: bind, closeListener: listener.NewCloseListener(), pausedCond: sync.NewCond(&sync.Mutex{}), mtu: mtu + bufsize.WGBufferOverhead, } return p } // AddTurnConn adds a new connection to the bind. // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. // // Parameters: // - ctx: Context is used for proxyToLocal to avoid unnecessary error messages // - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address // - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return nil } func (p *ProxyBind) EndpointAddr() *net.UDPAddr { return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) } func (p *ProxyBind) SetDisconnectListener(disconnected func()) { p.closeListener.SetCloseListener(disconnected) } func (p *ProxyBind) Work() { if p.remoteConn == nil { return } p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) p.pausedCond.L.Lock() p.paused = false p.wgCurrentUsed = p.wgRelayedEndpoint // Start the proxy only once if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } p.pausedCond.Signal() p.pausedCond.L.Unlock() } func (p *ProxyBind) Pause() { if p.remoteConn == nil { return } p.pausedCond.L.Lock() p.paused = true p.pausedCond.L.Unlock() } func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { ep, err := addrToEndpoint(endpoint) if err != nil { log.Errorf("failed to start package redirection: %v", err) return } p.pausedCond.L.Lock() p.paused = false p.wgCurrentUsed = ep p.pausedCond.Signal() p.pausedCond.L.Unlock() } func (p *ProxyBind) CloseConn() error { if p.cancel == nil { return fmt.Errorf("proxy not started") } return p.close() } func (p *ProxyBind) close() error { if p.remoteConn == nil { return nil } p.closeMu.Lock() defer p.closeMu.Unlock() if p.closed { return nil } p.closeListener.SetCloseListener(nil) p.closed = true p.cancel() p.pausedCond.L.Lock() p.paused = false p.pausedCond.Signal() p.pausedCond.L.Unlock() p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr } return nil } func (p *ProxyBind) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("failed to close remote conn: %s", err) } }() for { buf := make([]byte, p.mtu) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { return } p.closeListener.Notify() log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } p.pausedCond.L.Lock() for p.paused { p.pausedCond.Wait() } p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n]) p.pausedCond.L.Unlock() } } // fakeAddress returns a fake address that is used as an identifier for the peer. // The fake address is in the format of 127.1.x.x where x.x is derived from the // last two bytes of the peer address (works for both IPv4 and IPv6). func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { if peerAddress == nil { return nil, fmt.Errorf("nil peer address") } addr, ok := netip.AddrFromSlice(peerAddress.IP) if !ok { return nil, fmt.Errorf("invalid IP format") } addr = addr.Unmap() raw := addr.As16() fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]}) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil } func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) { if addr == nil { return nil, fmt.Errorf("invalid address") } ip, ok := netip.AddrFromSlice(addr.IP) if !ok { return nil, fmt.Errorf("convert %s to netip.Addr", addr) } addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port)) return &bind.Endpoint{AddrPort: addrPort}, nil }