mi: remove callbacks (#2188)

This commit is contained in:
Jan-Otto Kröpke
2025-08-26 21:04:56 +02:00
committed by GitHub
parent c8a4cb3806
commit 71cedbc4d0
6 changed files with 188 additions and 208 deletions

View File

@@ -5,6 +5,7 @@
<w>endpointstats</w>
<w>gochecknoglobals</w>
<w>luid</w>
<w>operationoptions</w>
<w>setupapi</w>
<w>spdx</w>
<w>textfile</w>

View File

@@ -1,181 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
//
// Copyright The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build windows
package mi
import (
"errors"
"fmt"
"math"
"reflect"
"sync"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
// operationUnmarshalCallbacksInstanceResult registers a global callback function.
// The amount of system callbacks is limited to 2000.
//
//nolint:gochecknoglobals
var operationUnmarshalCallbacksInstanceResult = sync.OnceValue[uintptr](func() uintptr {
// Workaround for a deadlock issue in go.
// Ref: https://github.com/golang/go/issues/55015
go time.Sleep(time.Duration(math.MaxInt64))
return windows.NewCallback(func(
operation *Operation,
callbacks *OperationUnmarshalCallbacks,
instance *Instance,
moreResults Boolean,
instanceResult ResultError,
errorMessageUTF16 *uint16,
errorDetails *Instance,
_ uintptr,
) uintptr {
if moreResults == False {
defer operation.Close()
}
return callbacks.InstanceResult(operation, instance, moreResults, instanceResult, errorMessageUTF16, errorDetails)
})
})
type OperationUnmarshalCallbacks struct {
dst any
dv reflect.Value
errCh chan<- error
elemType reflect.Type
elemValue reflect.Value
}
func NewUnmarshalOperationsCallbacks(dst any, errCh chan<- error) (*OperationCallbacks[OperationUnmarshalCallbacks], error) {
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || dv.IsNil() {
return nil, ErrInvalidEntityType
}
dv = dv.Elem()
elemType := dv.Type().Elem()
elemValue := reflect.ValueOf(reflect.New(elemType).Interface()).Elem()
if dv.Kind() != reflect.Slice || elemType.Kind() != reflect.Struct {
return nil, ErrInvalidEntityType
}
dv.Set(reflect.MakeSlice(dv.Type(), 0, 0))
return &OperationCallbacks[OperationUnmarshalCallbacks]{
CallbackContext: &OperationUnmarshalCallbacks{
errCh: errCh,
dst: dst,
dv: dv,
elemType: elemType,
elemValue: elemValue,
},
InstanceResult: operationUnmarshalCallbacksInstanceResult(),
}, nil
}
func (o *OperationUnmarshalCallbacks) InstanceResult(
_ *Operation,
instance *Instance,
moreResults Boolean,
instanceResult ResultError,
errorMessageUTF16 *uint16,
_ *Instance,
) uintptr {
defer func() {
if moreResults == False {
close(o.errCh)
}
}()
if !errors.Is(instanceResult, MI_RESULT_OK) {
o.errCh <- fmt.Errorf("%w: %s", instanceResult, windows.UTF16PtrToString(errorMessageUTF16))
return 0
}
if instance == nil {
return 0
}
counter, err := instance.GetElementCount()
if err != nil {
o.errCh <- fmt.Errorf("failed to get element count: %w", err)
return 0
}
if counter == 0 {
return 0
}
for i := range o.elemType.NumField() {
field := o.elemValue.Field(i)
// Check if the field has an `mi` tag
miTag := o.elemType.Field(i).Tag.Get("mi")
if miTag == "" {
continue
}
element, err := instance.GetElement(miTag)
if err != nil {
if errors.Is(err, MI_RESULT_NO_SUCH_PROPERTY) {
continue
}
o.errCh <- fmt.Errorf("failed to get element %s: %w", miTag, err)
return 0
}
switch element.valueType {
case ValueTypeBOOLEAN:
field.SetBool(element.value == 1)
case ValueTypeUINT8, ValueTypeUINT16, ValueTypeUINT32, ValueTypeUINT64:
field.SetUint(uint64(element.value))
case ValueTypeSINT8, ValueTypeSINT16, ValueTypeSINT32, ValueTypeSINT64:
field.SetInt(int64(element.value))
case ValueTypeSTRING:
if element.value == 0 {
// value is null
continue
}
// Convert the UTF-16 string to a Go string
stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value)))
field.SetString(stringValue)
case ValueTypeREAL32, ValueTypeREAL64:
field.SetFloat(float64(element.value))
default:
o.errCh <- fmt.Errorf("unsupported value type: %d", element.valueType)
return 0
}
}
o.dv.Set(reflect.Append(o.dv, o.elemValue))
return 0
}

View File

@@ -22,7 +22,9 @@ import (
"time"
"github.com/prometheus-community/windows_exporter/internal/mi"
"github.com/prometheus-community/windows_exporter/internal/utils/testutils"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows"
)
type win32Process struct {
@@ -234,3 +236,54 @@ func Test_MI_Query_Unmarshal(t *testing.T) {
err = application.Close()
require.NoError(t, err)
}
func Test_MI_FD_Leak(t *testing.T) {
application, err := mi.ApplicationInitialize()
require.NoError(t, err)
require.NotEmpty(t, application)
session, err := application.NewSession(nil)
require.NoError(t, err)
require.NotEmpty(t, session)
currentFileHandle, err := testutils.GetProcessHandleCount(windows.CurrentProcess())
require.NoError(t, err)
t.Log("Current File Handle Count: ", currentFileHandle)
queryPrinter, err := mi.NewQuery("SELECT Name FROM Win32_Process")
require.NoError(t, err)
for range 300 {
var processes []win32Process
err := session.Query(&processes, mi.NamespaceRootCIMv2, queryPrinter)
require.NoError(t, err)
currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
require.NoError(t, err)
t.Log("Current File Handle Count: ", currentFileHandle)
}
currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
require.NoError(t, err)
t.Log("Current File Handle Count: ", currentFileHandle)
err = session.Close()
require.NoError(t, err)
currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
require.NoError(t, err)
t.Log("Current File Handle Count: ", currentFileHandle)
err = application.Close()
require.NoError(t, err)
currentFileHandle, err = testutils.GetProcessHandleCount(windows.CurrentProcess())
require.NoError(t, err)
t.Log("Current File Handle Count: ", currentFileHandle)
}

View File

@@ -41,11 +41,7 @@ var OperationOptionsTimeout = UTF16PtrFromString[*uint16]("__MI_OPERATIONOPTIONS
type OperationFlags uint32
const (
OperationFlagsDefaultRTTI OperationFlags = 0x0000
OperationFlagsBasicRTTI OperationFlags = 0x0002
OperationFlagsNoRTTI OperationFlags = 0x0400
OperationFlagsStandardRTTI OperationFlags = 0x0800
OperationFlagsFullRTTI OperationFlags = 0x0004
)
// Operation represents an operation.
@@ -123,7 +119,7 @@ func (o *Operation) Cancel() error {
return ErrNotInitialized
}
r0, _, _ := syscall.SyscallN(o.ft.Close, uintptr(unsafe.Pointer(o)), 0)
r0, _, _ := syscall.SyscallN(o.ft.Cancel, uintptr(unsafe.Pointer(o)), 0)
if result := ResultError(r0); !errors.Is(result, MI_RESULT_OK) {
return result
@@ -229,10 +225,14 @@ func (o *Operation) Unmarshal(dst any) error {
field.SetInt(int64(element.value))
case ValueTypeSTRING:
if element.value == 0 {
return fmt.Errorf("%s: invalid pointer: value is nil", miTag)
field.SetString("") // Set empty string for nil values
continue
}
// Convert the UTF-16 string to a Go string
// Convert uintptr to *uint16 for Windows UTF-16 string
// This is safe because element.value comes directly from Windows MI API
//goland:noinspection GoVetUnsafePointer
stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value)))
field.SetString(stringValue)

View File

@@ -20,7 +20,7 @@ package mi
import (
"errors"
"fmt"
"runtime"
"reflect"
"syscall"
"unsafe"
@@ -200,13 +200,22 @@ func (s *Session) QueryUnmarshal(dst any,
operationOptions = s.defaultOperationOptions
}
errCh := make(chan error, 1)
operationCallbacks, err := NewUnmarshalOperationsCallbacks(dst, errCh)
if err != nil {
return err
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || dv.IsNil() {
return ErrInvalidEntityType
}
dv = dv.Elem()
elemType := dv.Type().Elem()
elemValue := reflect.ValueOf(reflect.New(elemType).Interface()).Elem()
if dv.Kind() != reflect.Slice || elemType.Kind() != reflect.Struct {
return ErrInvalidEntityType
}
dv.Set(reflect.MakeSlice(dv.Type(), 0, 0))
r0, _, _ := syscall.SyscallN(
s.ft.QueryInstances,
uintptr(unsafe.Pointer(s)),
@@ -215,7 +224,7 @@ func (s *Session) QueryUnmarshal(dst any,
uintptr(unsafe.Pointer(namespaceName)),
uintptr(unsafe.Pointer(queryDialect)),
uintptr(unsafe.Pointer(queryExpression)),
uintptr(unsafe.Pointer(operationCallbacks)),
0,
uintptr(unsafe.Pointer(operation)),
)
@@ -223,25 +232,79 @@ func (s *Session) QueryUnmarshal(dst any,
return result
}
errs := make([]error, 0)
// We need an active go routine to prevent a
// fatal error: all goroutines are asleep - deadlock!
// ref: https://github.com/golang/go/issues/55015
// go time.Sleep(5 * time.Second)
defer func() {
_ = operation.Close()
}()
for {
if err, ok := <-errCh; err != nil {
errs = append(errs, err)
} else if !ok {
instance, moreResults, err := operation.GetInstance()
if err != nil {
return fmt.Errorf("failed to get instance: %w", err)
}
if instance == nil {
break
}
counter, err := instance.GetElementCount()
if err != nil {
return fmt.Errorf("failed to get element count: %w", err)
}
if counter == 0 {
break
}
for i := range elemType.NumField() {
field := elemValue.Field(i)
// Check if the field has an `mi` tag
miTag := elemType.Field(i).Tag.Get("mi")
if miTag == "" {
continue
}
element, err := instance.GetElement(miTag)
if err != nil {
if errors.Is(err, MI_RESULT_NO_SUCH_PROPERTY) {
continue
}
return fmt.Errorf("failed to get element %s: %w", miTag, err)
}
switch element.valueType {
case ValueTypeBOOLEAN:
field.SetBool(element.value == 1)
case ValueTypeUINT8, ValueTypeUINT16, ValueTypeUINT32, ValueTypeUINT64:
field.SetUint(uint64(element.value))
case ValueTypeSINT8, ValueTypeSINT16, ValueTypeSINT32, ValueTypeSINT64:
field.SetInt(int64(element.value))
case ValueTypeSTRING:
if element.value == 0 {
// value is null
continue
}
// Convert the UTF-16 string to a Go string
stringValue := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(element.value)))
field.SetString(stringValue)
case ValueTypeREAL32, ValueTypeREAL64:
field.SetFloat(float64(element.value))
default:
return fmt.Errorf("unsupported value type: %d", element.valueType)
}
}
dv.Set(reflect.Append(dv, elemValue))
if !moreResults {
break
}
}
// KeepAlive is used to ensure that the callbacks are not garbage collected before the operation is closed.
runtime.KeepAlive(operationCallbacks.CallbackContext)
return errors.Join(errs...)
return nil
}
// Query queries for a set of instances based on a query expression.

View File

@@ -0,0 +1,44 @@
// SPDX-License-Identifier: Apache-2.0
//
// Copyright The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutils
import (
"unsafe"
"golang.org/x/sys/windows"
)
//nolint:gochecknoglobals
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procGetProcessHandleCount = modkernel32.NewProc("GetProcessHandleCount")
)
func GetProcessHandleCount(handle windows.Handle) (uint32, error) {
var count uint32
r1, _, err := procGetProcessHandleCount.Call(
uintptr(handle),
uintptr(unsafe.Pointer(&count)),
)
if r1 != 1 {
return 0, err
}
return count, nil
}