diff --git a/iface/dll.go b/iface/dll.go new file mode 100644 index 000000000..9f4481129 --- /dev/null +++ b/iface/dll.go @@ -0,0 +1,61 @@ +//go:build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package iface + +import ( + "fmt" + "golang.zx2c4.com/wireguard/windows/driver/memmod" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" +) + +type lazyDLL struct { + Name string + Base windows.Handle + mu sync.Mutex + module *memmod.Module + onLoad func(d *lazyDLL) +} + +func (d *lazyDLL) Load() error { + if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil { + return nil + } + d.mu.Lock() + defer d.mu.Unlock() + if d.module != nil { + return nil + } + + const ourModule windows.Handle = 0 + resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA) + if err != nil { + return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err) + } + data, err := windows.LoadResourceData(ourModule, resInfo) + if err != nil { + return fmt.Errorf("Unable to load resource: %w", err) + } + module, err := memmod.LoadLibrary(data) + if err != nil { + return fmt.Errorf("Unable to load library: %w", err) + } + d.Base = windows.Handle(module.BaseAddr()) + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) + if d.onLoad != nil { + d.onLoad(d) + } + return nil +} +func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL { + return &lazyDLL{Name: name, onLoad: onLoad} +} diff --git a/iface/tun_windows.go b/iface/tun_windows.go index 912529f5a..5c02a7ebc 100644 --- a/iface/tun_windows.go +++ b/iface/tun_windows.go @@ -37,11 +37,17 @@ func (c *tunDevice) Create() error { // createWithUserspace Creates a new WireGuard interface, using wireguard-go userspace implementation func (c *tunDevice) createWithUserspace() (NetInterface, error) { - dll := windows.NewLazyDLL("wintun.dll") + + dll := newLazyDLL("wintun.dll", func(d *lazyDLL) { + + }) + err := dll.Load() if err != nil { + log.Errorf("failed loading dll %v", err) return nil, err } + tunIface, err := tun.CreateTUN(c.name, c.mtu) if err != nil { return nil, err