From 71cedbc4d03db64bb5d67a0d597323f3e9617fed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan-Otto=20Kr=C3=B6pke?= Date: Tue, 26 Aug 2025 21:04:56 +0200 Subject: [PATCH] mi: remove callbacks (#2188) --- .idea/dictionaries/project.xml | 1 + internal/mi/callbacks.go | 181 ----------------------------- internal/mi/mi_test.go | 53 +++++++++ internal/mi/operation.go | 14 +-- internal/mi/session.go | 103 ++++++++++++---- internal/utils/testutils/handle.go | 44 +++++++ 6 files changed, 188 insertions(+), 208 deletions(-) delete mode 100644 internal/mi/callbacks.go create mode 100644 internal/utils/testutils/handle.go diff --git a/.idea/dictionaries/project.xml b/.idea/dictionaries/project.xml index 93dd824d..ecd06d91 100644 --- a/.idea/dictionaries/project.xml +++ b/.idea/dictionaries/project.xml @@ -5,6 +5,7 @@ endpointstats gochecknoglobals luid + operationoptions setupapi spdx textfile diff --git a/internal/mi/callbacks.go b/internal/mi/callbacks.go deleted file mode 100644 index f9ad7de0..00000000 --- a/internal/mi/callbacks.go +++ /dev/null @@ -1,181 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// -// Copyright The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build windows - -package mi - -import ( - "errors" - "fmt" - "math" - "reflect" - "sync" - "time" - "unsafe" - - "golang.org/x/sys/windows" -) - -// operationUnmarshalCallbacksInstanceResult registers a global callback function. -// The amount of system callbacks is limited to 2000. -// -//nolint:gochecknoglobals -var operationUnmarshalCallbacksInstanceResult = sync.OnceValue[uintptr](func() uintptr { - // Workaround for a deadlock issue in go. - // Ref: https://github.com/golang/go/issues/55015 - go time.Sleep(time.Duration(math.MaxInt64)) - - return windows.NewCallback(func( - operation *Operation, - callbacks *OperationUnmarshalCallbacks, - instance *Instance, - moreResults Boolean, - instanceResult ResultError, - errorMessageUTF16 *uint16, - errorDetails *Instance, - _ uintptr, - ) uintptr { - if moreResults == False { - defer operation.Close() - } - - return callbacks.InstanceResult(operation, instance, moreResults, instanceResult, errorMessageUTF16, errorDetails) - }) -}) - -type OperationUnmarshalCallbacks struct { - dst any - dv reflect.Value - errCh chan<- error - - elemType reflect.Type - elemValue reflect.Value -} - -func NewUnmarshalOperationsCallbacks(dst any, errCh chan<- error) (*OperationCallbacks[OperationUnmarshalCallbacks], error) { - dv := reflect.ValueOf(dst) - if dv.Kind() != reflect.Ptr || dv.IsNil() { - return nil, ErrInvalidEntityType - } - - dv = dv.Elem() - - elemType := dv.Type().Elem() - elemValue := reflect.ValueOf(reflect.New(elemType).Interface()).Elem() - - if dv.Kind() != reflect.Slice || elemType.Kind() != reflect.Struct { - return nil, ErrInvalidEntityType - } - - dv.Set(reflect.MakeSlice(dv.Type(), 0, 0)) - - return &OperationCallbacks[OperationUnmarshalCallbacks]{ - CallbackContext: &OperationUnmarshalCallbacks{ - errCh: errCh, - dst: dst, - dv: dv, - elemType: elemType, - elemValue: elemValue, - }, - InstanceResult: operationUnmarshalCallbacksInstanceResult(), - }, nil -} - -func (o *OperationUnmarshalCallbacks) InstanceResult( - _ *Operation, - instance *Instance, - moreResults Boolean, - instanceResult ResultError, - errorMessageUTF16 *uint16, - _ *Instance, -) uintptr { - defer func() { - if moreResults == False { - close(o.errCh) - } - }() - - if !errors.Is(instanceResult, MI_RESULT_OK) { - o.errCh <- fmt.Errorf("%w: %s", instanceResult, windows.UTF16PtrToString(errorMessageUTF16)) - - return 0 - } - - if instance == nil { - return 0 - } - - counter, err := instance.GetElementCount() - if err != nil { - o.errCh <- fmt.Errorf("failed to get element count: %w", err) - - return 0 - } - - if counter == 0 { - return 0 - } - - for i := range o.elemType.NumField() { - field := o.elemValue.Field(i) - - // Check if the field has an `mi` tag - miTag := o.elemType.Field(i).Tag.Get("mi") - if miTag == "" { - continue - } - - element, err := instance.GetElement(miTag) - if err != nil { - if errors.Is(err, MI_RESULT_NO_SUCH_PROPERTY) { - continue - } - - o.errCh <- fmt.Errorf("failed to get element %s: %w", miTag, err) - - return 0 - } - - switch element.valueType { - case ValueTypeBOOLEAN: - field.SetBool(element.value == 1) - case ValueTypeUINT8, ValueTypeUINT16, ValueTypeUINT32, ValueTypeUINT64: - field.SetUint(uint64(element.value)) - case ValueTypeSINT8, ValueTypeSINT16, ValueTypeSINT32, ValueTypeSINT64: - field.SetInt(int64(element.value)) - case ValueTypeSTRING: - if element.value == 0 { - // value is null - continue - } - - // Convert the UTF-16 string to a Go string - stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value))) - - field.SetString(stringValue) - case ValueTypeREAL32, ValueTypeREAL64: - field.SetFloat(float64(element.value)) - default: - o.errCh <- fmt.Errorf("unsupported value type: %d", element.valueType) - - return 0 - } - } - - o.dv.Set(reflect.Append(o.dv, o.elemValue)) - - return 0 -} diff --git a/internal/mi/mi_test.go b/internal/mi/mi_test.go index 7024c63b..1d6427b5 100644 --- a/internal/mi/mi_test.go +++ b/internal/mi/mi_test.go @@ -22,7 +22,9 @@ import ( "time" "github.com/prometheus-community/windows_exporter/internal/mi" + "github.com/prometheus-community/windows_exporter/internal/utils/testutils" "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" ) type win32Process struct { @@ -234,3 +236,54 @@ func Test_MI_Query_Unmarshal(t *testing.T) { err = application.Close() require.NoError(t, err) } + +func Test_MI_FD_Leak(t *testing.T) { + application, err := mi.ApplicationInitialize() + require.NoError(t, err) + require.NotEmpty(t, application) + + session, err := application.NewSession(nil) + require.NoError(t, err) + require.NotEmpty(t, session) + + currentFileHandle, err := testutils.GetProcessHandleCount(windows.CurrentProcess()) + require.NoError(t, err) + + t.Log("Current File Handle Count: ", currentFileHandle) + + queryPrinter, err := mi.NewQuery("SELECT Name FROM Win32_Process") + require.NoError(t, err) + + for range 300 { + var processes []win32Process + + err := session.Query(&processes, mi.NamespaceRootCIMv2, queryPrinter) + require.NoError(t, err) + + currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess()) + require.NoError(t, err) + + t.Log("Current File Handle Count: ", currentFileHandle) + } + + currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess()) + require.NoError(t, err) + + t.Log("Current File Handle Count: ", currentFileHandle) + + err = session.Close() + require.NoError(t, err) + + currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess()) + require.NoError(t, err) + + t.Log("Current File Handle Count: ", currentFileHandle) + + err = application.Close() + require.NoError(t, err) + + currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess()) + require.NoError(t, err) + + t.Log("Current File Handle Count: ", currentFileHandle) +} diff --git a/internal/mi/operation.go b/internal/mi/operation.go index f90e9946..c187625b 100644 --- a/internal/mi/operation.go +++ b/internal/mi/operation.go @@ -41,11 +41,7 @@ var OperationOptionsTimeout = UTF16PtrFromString[*uint16]("__MI_OPERATIONOPTIONS type OperationFlags uint32 const ( - OperationFlagsDefaultRTTI OperationFlags = 0x0000 - OperationFlagsBasicRTTI OperationFlags = 0x0002 - OperationFlagsNoRTTI OperationFlags = 0x0400 OperationFlagsStandardRTTI OperationFlags = 0x0800 - OperationFlagsFullRTTI OperationFlags = 0x0004 ) // Operation represents an operation. @@ -123,7 +119,7 @@ func (o *Operation) Cancel() error { return ErrNotInitialized } - r0, _, _ := syscall.SyscallN(o.ft.Close, uintptr(unsafe.Pointer(o)), 0) + r0, _, _ := syscall.SyscallN(o.ft.Cancel, uintptr(unsafe.Pointer(o)), 0) if result := ResultError(r0); !errors.Is(result, MI_RESULT_OK) { return result @@ -229,10 +225,14 @@ func (o *Operation) Unmarshal(dst any) error { field.SetInt(int64(element.value)) case ValueTypeSTRING: if element.value == 0 { - return fmt.Errorf("%s: invalid pointer: value is nil", miTag) + field.SetString("") // Set empty string for nil values + + continue } - // Convert the UTF-16 string to a Go string + // Convert uintptr to *uint16 for Windows UTF-16 string + // This is safe because element.value comes directly from Windows MI API + //goland:noinspection GoVetUnsafePointer stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value))) field.SetString(stringValue) diff --git a/internal/mi/session.go b/internal/mi/session.go index 33fb7eb7..f77f87d3 100644 --- a/internal/mi/session.go +++ b/internal/mi/session.go @@ -20,7 +20,7 @@ package mi import ( "errors" "fmt" - "runtime" + "reflect" "syscall" "unsafe" @@ -200,13 +200,22 @@ func (s *Session) QueryUnmarshal(dst any, operationOptions = s.defaultOperationOptions } - errCh := make(chan error, 1) - - operationCallbacks, err := NewUnmarshalOperationsCallbacks(dst, errCh) - if err != nil { - return err + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || dv.IsNil() { + return ErrInvalidEntityType } + dv = dv.Elem() + + elemType := dv.Type().Elem() + elemValue := reflect.ValueOf(reflect.New(elemType).Interface()).Elem() + + if dv.Kind() != reflect.Slice || elemType.Kind() != reflect.Struct { + return ErrInvalidEntityType + } + + dv.Set(reflect.MakeSlice(dv.Type(), 0, 0)) + r0, _, _ := syscall.SyscallN( s.ft.QueryInstances, uintptr(unsafe.Pointer(s)), @@ -215,7 +224,7 @@ func (s *Session) QueryUnmarshal(dst any, uintptr(unsafe.Pointer(namespaceName)), uintptr(unsafe.Pointer(queryDialect)), uintptr(unsafe.Pointer(queryExpression)), - uintptr(unsafe.Pointer(operationCallbacks)), + 0, uintptr(unsafe.Pointer(operation)), ) @@ -223,25 +232,79 @@ func (s *Session) QueryUnmarshal(dst any, return result } - errs := make([]error, 0) - - // We need an active go routine to prevent a - // fatal error: all goroutines are asleep - deadlock! - // ref: https://github.com/golang/go/issues/55015 - // go time.Sleep(5 * time.Second) + defer func() { + _ = operation.Close() + }() for { - if err, ok := <-errCh; err != nil { - errs = append(errs, err) - } else if !ok { + instance, moreResults, err := operation.GetInstance() + if err != nil { + return fmt.Errorf("failed to get instance: %w", err) + } + + if instance == nil { + break + } + + counter, err := instance.GetElementCount() + if err != nil { + return fmt.Errorf("failed to get element count: %w", err) + } + + if counter == 0 { + break + } + + for i := range elemType.NumField() { + field := elemValue.Field(i) + + // Check if the field has an `mi` tag + miTag := elemType.Field(i).Tag.Get("mi") + if miTag == "" { + continue + } + + element, err := instance.GetElement(miTag) + if err != nil { + if errors.Is(err, MI_RESULT_NO_SUCH_PROPERTY) { + continue + } + + return fmt.Errorf("failed to get element %s: %w", miTag, err) + } + + switch element.valueType { + case ValueTypeBOOLEAN: + field.SetBool(element.value == 1) + case ValueTypeUINT8, ValueTypeUINT16, ValueTypeUINT32, ValueTypeUINT64: + field.SetUint(uint64(element.value)) + case ValueTypeSINT8, ValueTypeSINT16, ValueTypeSINT32, ValueTypeSINT64: + field.SetInt(int64(element.value)) + case ValueTypeSTRING: + if element.value == 0 { + // value is null + continue + } + + // Convert the UTF-16 string to a Go string + stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value))) + + field.SetString(stringValue) + case ValueTypeREAL32, ValueTypeREAL64: + field.SetFloat(float64(element.value)) + default: + return fmt.Errorf("unsupported value type: %d", element.valueType) + } + } + + dv.Set(reflect.Append(dv, elemValue)) + + if !moreResults { break } } - // KeepAlive is used to ensure that the callbacks are not garbage collected before the operation is closed. - runtime.KeepAlive(operationCallbacks.CallbackContext) - - return errors.Join(errs...) + return nil } // Query queries for a set of instances based on a query expression. diff --git a/internal/utils/testutils/handle.go b/internal/utils/testutils/handle.go new file mode 100644 index 00000000..dff4182a --- /dev/null +++ b/internal/utils/testutils/handle.go @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutils + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +//nolint:gochecknoglobals +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procGetProcessHandleCount = modkernel32.NewProc("GetProcessHandleCount") +) + +func GetProcessHandleCount(handle windows.Handle) (uint32, error) { + var count uint32 + + r1, _, err := procGetProcessHandleCount.Call( + uintptr(handle), + uintptr(unsafe.Pointer(&count)), + ) + + if r1 != 1 { + return 0, err + } + + return count, nil +}