diff --git a/.idea/dictionaries/project.xml b/.idea/dictionaries/project.xml
index 93dd824d..ecd06d91 100644
--- a/.idea/dictionaries/project.xml
+++ b/.idea/dictionaries/project.xml
@@ -5,6 +5,7 @@
endpointstats
gochecknoglobals
luid
+ operationoptions
setupapi
spdx
textfile
diff --git a/internal/mi/callbacks.go b/internal/mi/callbacks.go
deleted file mode 100644
index f9ad7de0..00000000
--- a/internal/mi/callbacks.go
+++ /dev/null
@@ -1,181 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-//
-// Copyright The Prometheus Authors
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build windows
-
-package mi
-
-import (
- "errors"
- "fmt"
- "math"
- "reflect"
- "sync"
- "time"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-// operationUnmarshalCallbacksInstanceResult registers a global callback function.
-// The amount of system callbacks is limited to 2000.
-//
-//nolint:gochecknoglobals
-var operationUnmarshalCallbacksInstanceResult = sync.OnceValue[uintptr](func() uintptr {
- // Workaround for a deadlock issue in go.
- // Ref: https://github.com/golang/go/issues/55015
- go time.Sleep(time.Duration(math.MaxInt64))
-
- return windows.NewCallback(func(
- operation *Operation,
- callbacks *OperationUnmarshalCallbacks,
- instance *Instance,
- moreResults Boolean,
- instanceResult ResultError,
- errorMessageUTF16 *uint16,
- errorDetails *Instance,
- _ uintptr,
- ) uintptr {
- if moreResults == False {
- defer operation.Close()
- }
-
- return callbacks.InstanceResult(operation, instance, moreResults, instanceResult, errorMessageUTF16, errorDetails)
- })
-})
-
-type OperationUnmarshalCallbacks struct {
- dst any
- dv reflect.Value
- errCh chan<- error
-
- elemType reflect.Type
- elemValue reflect.Value
-}
-
-func NewUnmarshalOperationsCallbacks(dst any, errCh chan<- error) (*OperationCallbacks[OperationUnmarshalCallbacks], error) {
- dv := reflect.ValueOf(dst)
- if dv.Kind() != reflect.Ptr || dv.IsNil() {
- return nil, ErrInvalidEntityType
- }
-
- dv = dv.Elem()
-
- elemType := dv.Type().Elem()
- elemValue := reflect.ValueOf(reflect.New(elemType).Interface()).Elem()
-
- if dv.Kind() != reflect.Slice || elemType.Kind() != reflect.Struct {
- return nil, ErrInvalidEntityType
- }
-
- dv.Set(reflect.MakeSlice(dv.Type(), 0, 0))
-
- return &OperationCallbacks[OperationUnmarshalCallbacks]{
- CallbackContext: &OperationUnmarshalCallbacks{
- errCh: errCh,
- dst: dst,
- dv: dv,
- elemType: elemType,
- elemValue: elemValue,
- },
- InstanceResult: operationUnmarshalCallbacksInstanceResult(),
- }, nil
-}
-
-func (o *OperationUnmarshalCallbacks) InstanceResult(
- _ *Operation,
- instance *Instance,
- moreResults Boolean,
- instanceResult ResultError,
- errorMessageUTF16 *uint16,
- _ *Instance,
-) uintptr {
- defer func() {
- if moreResults == False {
- close(o.errCh)
- }
- }()
-
- if !errors.Is(instanceResult, MI_RESULT_OK) {
- o.errCh <- fmt.Errorf("%w: %s", instanceResult, windows.UTF16PtrToString(errorMessageUTF16))
-
- return 0
- }
-
- if instance == nil {
- return 0
- }
-
- counter, err := instance.GetElementCount()
- if err != nil {
- o.errCh <- fmt.Errorf("failed to get element count: %w", err)
-
- return 0
- }
-
- if counter == 0 {
- return 0
- }
-
- for i := range o.elemType.NumField() {
- field := o.elemValue.Field(i)
-
- // Check if the field has an `mi` tag
- miTag := o.elemType.Field(i).Tag.Get("mi")
- if miTag == "" {
- continue
- }
-
- element, err := instance.GetElement(miTag)
- if err != nil {
- if errors.Is(err, MI_RESULT_NO_SUCH_PROPERTY) {
- continue
- }
-
- o.errCh <- fmt.Errorf("failed to get element %s: %w", miTag, err)
-
- return 0
- }
-
- switch element.valueType {
- case ValueTypeBOOLEAN:
- field.SetBool(element.value == 1)
- case ValueTypeUINT8, ValueTypeUINT16, ValueTypeUINT32, ValueTypeUINT64:
- field.SetUint(uint64(element.value))
- case ValueTypeSINT8, ValueTypeSINT16, ValueTypeSINT32, ValueTypeSINT64:
- field.SetInt(int64(element.value))
- case ValueTypeSTRING:
- if element.value == 0 {
- // value is null
- continue
- }
-
- // Convert the UTF-16 string to a Go string
- stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value)))
-
- field.SetString(stringValue)
- case ValueTypeREAL32, ValueTypeREAL64:
- field.SetFloat(float64(element.value))
- default:
- o.errCh <- fmt.Errorf("unsupported value type: %d", element.valueType)
-
- return 0
- }
- }
-
- o.dv.Set(reflect.Append(o.dv, o.elemValue))
-
- return 0
-}
diff --git a/internal/mi/mi_test.go b/internal/mi/mi_test.go
index 7024c63b..1d6427b5 100644
--- a/internal/mi/mi_test.go
+++ b/internal/mi/mi_test.go
@@ -22,7 +22,9 @@ import (
"time"
"github.com/prometheus-community/windows_exporter/internal/mi"
+ "github.com/prometheus-community/windows_exporter/internal/utils/testutils"
"github.com/stretchr/testify/require"
+ "golang.org/x/sys/windows"
)
type win32Process struct {
@@ -234,3 +236,54 @@ func Test_MI_Query_Unmarshal(t *testing.T) {
err = application.Close()
require.NoError(t, err)
}
+
+func Test_MI_FD_Leak(t *testing.T) {
+ application, err := mi.ApplicationInitialize()
+ require.NoError(t, err)
+ require.NotEmpty(t, application)
+
+ session, err := application.NewSession(nil)
+ require.NoError(t, err)
+ require.NotEmpty(t, session)
+
+ currentFileHandle, err := testutils.GetProcessHandleCount(windows.CurrentProcess())
+ require.NoError(t, err)
+
+ t.Log("Current File Handle Count: ", currentFileHandle)
+
+ queryPrinter, err := mi.NewQuery("SELECT Name FROM Win32_Process")
+ require.NoError(t, err)
+
+ for range 300 {
+ var processes []win32Process
+
+ err := session.Query(&processes, mi.NamespaceRootCIMv2, queryPrinter)
+ require.NoError(t, err)
+
+ currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
+ require.NoError(t, err)
+
+ t.Log("Current File Handle Count: ", currentFileHandle)
+ }
+
+ currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
+ require.NoError(t, err)
+
+ t.Log("Current File Handle Count: ", currentFileHandle)
+
+ err = session.Close()
+ require.NoError(t, err)
+
+ currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
+ require.NoError(t, err)
+
+ t.Log("Current File Handle Count: ", currentFileHandle)
+
+ err = application.Close()
+ require.NoError(t, err)
+
+ currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
+ require.NoError(t, err)
+
+ t.Log("Current File Handle Count: ", currentFileHandle)
+}
diff --git a/internal/mi/operation.go b/internal/mi/operation.go
index f90e9946..c187625b 100644
--- a/internal/mi/operation.go
+++ b/internal/mi/operation.go
@@ -41,11 +41,7 @@ var OperationOptionsTimeout = UTF16PtrFromString[*uint16]("__MI_OPERATIONOPTIONS
type OperationFlags uint32
const (
- OperationFlagsDefaultRTTI OperationFlags = 0x0000
- OperationFlagsBasicRTTI OperationFlags = 0x0002
- OperationFlagsNoRTTI OperationFlags = 0x0400
OperationFlagsStandardRTTI OperationFlags = 0x0800
- OperationFlagsFullRTTI OperationFlags = 0x0004
)
// Operation represents an operation.
@@ -123,7 +119,7 @@ func (o *Operation) Cancel() error {
return ErrNotInitialized
}
- r0, _, _ := syscall.SyscallN(o.ft.Close, uintptr(unsafe.Pointer(o)), 0)
+ r0, _, _ := syscall.SyscallN(o.ft.Cancel, uintptr(unsafe.Pointer(o)), 0)
if result := ResultError(r0); !errors.Is(result, MI_RESULT_OK) {
return result
@@ -229,10 +225,14 @@ func (o *Operation) Unmarshal(dst any) error {
field.SetInt(int64(element.value))
case ValueTypeSTRING:
if element.value == 0 {
- return fmt.Errorf("%s: invalid pointer: value is nil", miTag)
+ field.SetString("") // Set empty string for nil values
+
+ continue
}
- // Convert the UTF-16 string to a Go string
+ // Convert uintptr to *uint16 for Windows UTF-16 string
+ // This is safe because element.value comes directly from Windows MI API
+ //goland:noinspection GoVetUnsafePointer
stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value)))
field.SetString(stringValue)
diff --git a/internal/mi/session.go b/internal/mi/session.go
index 33fb7eb7..f77f87d3 100644
--- a/internal/mi/session.go
+++ b/internal/mi/session.go
@@ -20,7 +20,7 @@ package mi
import (
"errors"
"fmt"
- "runtime"
+ "reflect"
"syscall"
"unsafe"
@@ -200,13 +200,22 @@ func (s *Session) QueryUnmarshal(dst any,
operationOptions = s.defaultOperationOptions
}
- errCh := make(chan error, 1)
-
- operationCallbacks, err := NewUnmarshalOperationsCallbacks(dst, errCh)
- if err != nil {
- return err
+ dv := reflect.ValueOf(dst)
+ if dv.Kind() != reflect.Ptr || dv.IsNil() {
+ return ErrInvalidEntityType
}
+ dv = dv.Elem()
+
+ elemType := dv.Type().Elem()
+ elemValue := reflect.ValueOf(reflect.New(elemType).Interface()).Elem()
+
+ if dv.Kind() != reflect.Slice || elemType.Kind() != reflect.Struct {
+ return ErrInvalidEntityType
+ }
+
+ dv.Set(reflect.MakeSlice(dv.Type(), 0, 0))
+
r0, _, _ := syscall.SyscallN(
s.ft.QueryInstances,
uintptr(unsafe.Pointer(s)),
@@ -215,7 +224,7 @@ func (s *Session) QueryUnmarshal(dst any,
uintptr(unsafe.Pointer(namespaceName)),
uintptr(unsafe.Pointer(queryDialect)),
uintptr(unsafe.Pointer(queryExpression)),
- uintptr(unsafe.Pointer(operationCallbacks)),
+ 0,
uintptr(unsafe.Pointer(operation)),
)
@@ -223,25 +232,79 @@ func (s *Session) QueryUnmarshal(dst any,
return result
}
- errs := make([]error, 0)
-
- // We need an active go routine to prevent a
- // fatal error: all goroutines are asleep - deadlock!
- // ref: https://github.com/golang/go/issues/55015
- // go time.Sleep(5 * time.Second)
+ defer func() {
+ _ = operation.Close()
+ }()
for {
- if err, ok := <-errCh; err != nil {
- errs = append(errs, err)
- } else if !ok {
+ instance, moreResults, err := operation.GetInstance()
+ if err != nil {
+ return fmt.Errorf("failed to get instance: %w", err)
+ }
+
+ if instance == nil {
+ break
+ }
+
+ counter, err := instance.GetElementCount()
+ if err != nil {
+ return fmt.Errorf("failed to get element count: %w", err)
+ }
+
+ if counter == 0 {
+ break
+ }
+
+ for i := range elemType.NumField() {
+ field := elemValue.Field(i)
+
+ // Check if the field has an `mi` tag
+ miTag := elemType.Field(i).Tag.Get("mi")
+ if miTag == "" {
+ continue
+ }
+
+ element, err := instance.GetElement(miTag)
+ if err != nil {
+ if errors.Is(err, MI_RESULT_NO_SUCH_PROPERTY) {
+ continue
+ }
+
+ return fmt.Errorf("failed to get element %s: %w", miTag, err)
+ }
+
+ switch element.valueType {
+ case ValueTypeBOOLEAN:
+ field.SetBool(element.value == 1)
+ case ValueTypeUINT8, ValueTypeUINT16, ValueTypeUINT32, ValueTypeUINT64:
+ field.SetUint(uint64(element.value))
+ case ValueTypeSINT8, ValueTypeSINT16, ValueTypeSINT32, ValueTypeSINT64:
+ field.SetInt(int64(element.value))
+ case ValueTypeSTRING:
+ if element.value == 0 {
+ // value is null
+ continue
+ }
+
+ // Convert the UTF-16 string to a Go string
+ stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value)))
+
+ field.SetString(stringValue)
+ case ValueTypeREAL32, ValueTypeREAL64:
+ field.SetFloat(float64(element.value))
+ default:
+ return fmt.Errorf("unsupported value type: %d", element.valueType)
+ }
+ }
+
+ dv.Set(reflect.Append(dv, elemValue))
+
+ if !moreResults {
break
}
}
- // KeepAlive is used to ensure that the callbacks are not garbage collected before the operation is closed.
- runtime.KeepAlive(operationCallbacks.CallbackContext)
-
- return errors.Join(errs...)
+ return nil
}
// Query queries for a set of instances based on a query expression.
diff --git a/internal/utils/testutils/handle.go b/internal/utils/testutils/handle.go
new file mode 100644
index 00000000..dff4182a
--- /dev/null
+++ b/internal/utils/testutils/handle.go
@@ -0,0 +1,44 @@
+// SPDX-License-Identifier: Apache-2.0
+//
+// Copyright The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutils
+
+import (
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+//nolint:gochecknoglobals
+var (
+ modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
+
+ procGetProcessHandleCount = modkernel32.NewProc("GetProcessHandleCount")
+)
+
+func GetProcessHandleCount(handle windows.Handle) (uint32, error) {
+ var count uint32
+
+ r1, _, err := procGetProcessHandleCount.Call(
+ uintptr(handle),
+ uintptr(unsafe.Pointer(&count)),
+ )
+
+ if r1 != 1 {
+ return 0, err
+ }
+
+ return count, nil
+}