blob: a951b1258b04d6645df7bd66569cc48e9cf36183 [file] [log] [blame]
Serge Bazanskicc25bdf2018-10-25 14:02:58 +02001// +build windows
2
3/*
4Package wmi provides a WQL interface for WMI on Windows.
5
6Example code to print names of running processes:
7
8 type Win32_Process struct {
9 Name string
10 }
11
12 func main() {
13 var dst []Win32_Process
14 q := wmi.CreateQuery(&dst, "")
15 err := wmi.Query(q, &dst)
16 if err != nil {
17 log.Fatal(err)
18 }
19 for i, v := range dst {
20 println(i, v.Name)
21 }
22 }
23
24*/
25package wmi
26
27import (
28 "bytes"
29 "errors"
30 "fmt"
31 "log"
32 "os"
33 "reflect"
34 "runtime"
35 "strconv"
36 "strings"
37 "sync"
38 "time"
39
40 "github.com/go-ole/go-ole"
41 "github.com/go-ole/go-ole/oleutil"
42)
43
// l is the package-level logger. It writes to os.Stdout with no prefix and
// the standard log flags.
var l = log.New(os.Stdout, "", log.LstdFlags)

var (
	// ErrInvalidEntityType is returned by Client.Query when dst is not a
	// non-nil pointer to a slice of structs or of struct pointers.
	ErrInvalidEntityType = errors.New("wmi: invalid entity type")
	// ErrNilCreateObject is the error returned if CreateObject returns nil even
	// if the error was nil.
	ErrNilCreateObject = errors.New("wmi: create object returned nil")
	// lock is held for the whole duration of Client.Query, so those calls
	// run one at a time.
	lock sync.Mutex
)

// S_FALSE is returned by CoInitializeEx if it was already called on this thread.
// Client.Query treats this code as a successful initialization.
const S_FALSE = 0x00000001
56
57// QueryNamespace invokes Query with the given namespace on the local machine.
58func QueryNamespace(query string, dst interface{}, namespace string) error {
59 return Query(query, dst, nil, namespace)
60}
61
62// Query runs the WQL query and appends the values to dst.
63//
64// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
65// the query must have the same name in dst. Supported types are all signed and
66// unsigned integers, time.Time, string, bool, or a pointer to one of those.
67// Array types are not supported.
68//
69// By default, the local machine and default namespace are used. These can be
70// changed using connectServerArgs. See
71// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
72//
73// Query is a wrapper around DefaultClient.Query.
74func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
75 if DefaultClient.SWbemServicesClient == nil {
76 return DefaultClient.Query(query, dst, connectServerArgs...)
77 }
78 return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...)
79}
80
// A Client is a WMI query client.
//
// Its zero value (DefaultClient) is a usable client.
type Client struct {
	// NonePtrZero specifies if nil values for fields which aren't pointers
	// should be returned as the field type's zero value.
	//
	// Setting this to true allows structs without pointer fields to be used
	// without the risk of failure should a nil value be returned from WMI.
	NonePtrZero bool

	// PtrNil specifies if nil values for pointer fields should be returned
	// as nil.
	//
	// Setting this to true will set pointer fields to nil where WMI
	// returned nil, otherwise the type's zero value will be returned.
	PtrNil bool

	// AllowMissingFields specifies that struct fields not present in the
	// query result should not result in an error.
	//
	// Setting this to true allows custom queries to be used with full
	// struct definitions instead of having to define multiple structs.
	AllowMissingFields bool

	// SWbemServicesClient is an optional SWbemServices object that can be
	// initialized and then reused across multiple queries. If it is nil
	// then the method will initialize a new temporary client each time.
	SWbemServicesClient *SWbemServices
}
111
// DefaultClient is the default Client. It is used by the package-level
// Query function, and therefore by QueryNamespace as well.
var DefaultClient = &Client{}
114
115// Query runs the WQL query and appends the values to dst.
116//
117// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
118// the query must have the same name in dst. Supported types are all signed and
119// unsigned integers, time.Time, string, bool, or a pointer to one of those.
120// Array types are not supported.
121//
122// By default, the local machine and default namespace are used. These can be
123// changed using connectServerArgs. See
124// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
125func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
126 dv := reflect.ValueOf(dst)
127 if dv.Kind() != reflect.Ptr || dv.IsNil() {
128 return ErrInvalidEntityType
129 }
130 dv = dv.Elem()
131 mat, elemType := checkMultiArg(dv)
132 if mat == multiArgTypeInvalid {
133 return ErrInvalidEntityType
134 }
135
136 lock.Lock()
137 defer lock.Unlock()
138 runtime.LockOSThread()
139 defer runtime.UnlockOSThread()
140
141 err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
142 if err != nil {
143 oleCode := err.(*ole.OleError).Code()
144 if oleCode != ole.S_OK && oleCode != S_FALSE {
145 return err
146 }
147 }
148 defer ole.CoUninitialize()
149
150 unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
151 if err != nil {
152 return err
153 } else if unknown == nil {
154 return ErrNilCreateObject
155 }
156 defer unknown.Release()
157
158 wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
159 if err != nil {
160 return err
161 }
162 defer wmi.Release()
163
164 // service is a SWbemServices
165 serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
166 if err != nil {
167 return err
168 }
169 service := serviceRaw.ToIDispatch()
170 defer serviceRaw.Clear()
171
172 // result is a SWBemObjectSet
173 resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
174 if err != nil {
175 return err
176 }
177 result := resultRaw.ToIDispatch()
178 defer resultRaw.Clear()
179
180 count, err := oleInt64(result, "Count")
181 if err != nil {
182 return err
183 }
184
185 enumProperty, err := result.GetProperty("_NewEnum")
186 if err != nil {
187 return err
188 }
189 defer enumProperty.Clear()
190
191 enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
192 if err != nil {
193 return err
194 }
195 if enum == nil {
196 return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
197 }
198 defer enum.Release()
199
200 // Initialize a slice with Count capacity
201 dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
202
203 var errFieldMismatch error
204 for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
205 if err != nil {
206 return err
207 }
208
209 err := func() error {
210 // item is a SWbemObject, but really a Win32_Process
211 item := itemRaw.ToIDispatch()
212 defer item.Release()
213
214 ev := reflect.New(elemType)
215 if err = c.loadEntity(ev.Interface(), item); err != nil {
216 if _, ok := err.(*ErrFieldMismatch); ok {
217 // We continue loading entities even in the face of field mismatch errors.
218 // If we encounter any other error, that other error is returned. Otherwise,
219 // an ErrFieldMismatch is returned.
220 errFieldMismatch = err
221 } else {
222 return err
223 }
224 }
225 if mat != multiArgTypeStructPtr {
226 ev = ev.Elem()
227 }
228 dv.Set(reflect.Append(dv, ev))
229 return nil
230 }()
231 if err != nil {
232 return err
233 }
234 }
235 return errFieldMismatch
236}
237
// ErrFieldMismatch reports that a WMI property could not be loaded into a
// destination field: the field has a different type than the stored value,
// or the field is missing or unexported in the destination struct.
//
// StructType is documented as the type of the struct pointed to by the
// destination argument.
// NOTE(review): loadEntity in this file fills it with the destination
// field's type instead; confirm which is intended.
type ErrFieldMismatch struct {
	StructType reflect.Type
	FieldName  string
	Reason     string
}

// Error implements the error interface, naming the field, the recorded type
// and the reason for the mismatch.
func (e *ErrFieldMismatch) Error() string {
	const format = "wmi: cannot load field %q into a %q: %s"
	return fmt.Sprintf(format, e.FieldName, e.StructType, e.Reason)
}
252
// timeType is the reflect.Type of time.Time. loadEntity compares struct
// field types against it to recognize time.Time destinations.
var timeType = reflect.TypeOf(time.Time{})
254
255// loadEntity loads a SWbemObject into a struct pointer.
256func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
257 v := reflect.ValueOf(dst).Elem()
258 for i := 0; i < v.NumField(); i++ {
259 f := v.Field(i)
260 of := f
261 isPtr := f.Kind() == reflect.Ptr
262 if isPtr {
263 ptr := reflect.New(f.Type().Elem())
264 f.Set(ptr)
265 f = f.Elem()
266 }
267 n := v.Type().Field(i).Name
268 if !f.CanSet() {
269 return &ErrFieldMismatch{
270 StructType: of.Type(),
271 FieldName: n,
272 Reason: "CanSet() is false",
273 }
274 }
275 prop, err := oleutil.GetProperty(src, n)
276 if err != nil {
277 if !c.AllowMissingFields {
278 errFieldMismatch = &ErrFieldMismatch{
279 StructType: of.Type(),
280 FieldName: n,
281 Reason: "no such struct field",
282 }
283 }
284 continue
285 }
286 defer prop.Clear()
287
288 switch val := prop.Value().(type) {
289 case int8, int16, int32, int64, int:
290 v := reflect.ValueOf(val).Int()
291 switch f.Kind() {
292 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
293 f.SetInt(v)
294 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
295 f.SetUint(uint64(v))
296 default:
297 return &ErrFieldMismatch{
298 StructType: of.Type(),
299 FieldName: n,
300 Reason: "not an integer class",
301 }
302 }
303 case uint8, uint16, uint32, uint64:
304 v := reflect.ValueOf(val).Uint()
305 switch f.Kind() {
306 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
307 f.SetInt(int64(v))
308 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
309 f.SetUint(v)
310 default:
311 return &ErrFieldMismatch{
312 StructType: of.Type(),
313 FieldName: n,
314 Reason: "not an integer class",
315 }
316 }
317 case string:
318 switch f.Kind() {
319 case reflect.String:
320 f.SetString(val)
321 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
322 iv, err := strconv.ParseInt(val, 10, 64)
323 if err != nil {
324 return err
325 }
326 f.SetInt(iv)
327 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
328 uv, err := strconv.ParseUint(val, 10, 64)
329 if err != nil {
330 return err
331 }
332 f.SetUint(uv)
333 case reflect.Struct:
334 switch f.Type() {
335 case timeType:
336 if len(val) == 25 {
337 mins, err := strconv.Atoi(val[22:])
338 if err != nil {
339 return err
340 }
341 val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
342 }
343 t, err := time.Parse("20060102150405.000000-0700", val)
344 if err != nil {
345 return err
346 }
347 f.Set(reflect.ValueOf(t))
348 }
349 }
350 case bool:
351 switch f.Kind() {
352 case reflect.Bool:
353 f.SetBool(val)
354 default:
355 return &ErrFieldMismatch{
356 StructType: of.Type(),
357 FieldName: n,
358 Reason: "not a bool",
359 }
360 }
361 case float32:
362 switch f.Kind() {
363 case reflect.Float32:
364 f.SetFloat(float64(val))
365 default:
366 return &ErrFieldMismatch{
367 StructType: of.Type(),
368 FieldName: n,
369 Reason: "not a Float32",
370 }
371 }
372 default:
373 if f.Kind() == reflect.Slice {
374 switch f.Type().Elem().Kind() {
375 case reflect.String:
376 safeArray := prop.ToArray()
377 if safeArray != nil {
378 arr := safeArray.ToValueArray()
379 fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
380 for i, v := range arr {
381 s := fArr.Index(i)
382 s.SetString(v.(string))
383 }
384 f.Set(fArr)
385 }
386 case reflect.Uint8:
387 safeArray := prop.ToArray()
388 if safeArray != nil {
389 arr := safeArray.ToValueArray()
390 fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
391 for i, v := range arr {
392 s := fArr.Index(i)
393 s.SetUint(reflect.ValueOf(v).Uint())
394 }
395 f.Set(fArr)
396 }
397 default:
398 return &ErrFieldMismatch{
399 StructType: of.Type(),
400 FieldName: n,
401 Reason: fmt.Sprintf("unsupported slice type (%T)", val),
402 }
403 }
404 } else {
405 typeof := reflect.TypeOf(val)
406 if typeof == nil && (isPtr || c.NonePtrZero) {
407 if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) {
408 of.Set(reflect.Zero(of.Type()))
409 }
410 break
411 }
412 return &ErrFieldMismatch{
413 StructType: of.Type(),
414 FieldName: n,
415 Reason: fmt.Sprintf("unsupported type (%T)", val),
416 }
417 }
418 }
419 }
420 return errFieldMismatch
421}
422
// multiArgType classifies the element type of a destination slice.
type multiArgType int

const (
	// multiArgTypeInvalid marks a value that is not a slice of structs or
	// of struct pointers.
	multiArgTypeInvalid multiArgType = iota
	// multiArgTypeStruct marks a slice whose elements are structs ([]S).
	multiArgTypeStruct
	// multiArgTypeStructPtr marks a slice whose elements are pointers to
	// structs ([]*S).
	multiArgTypeStructPtr
)

// checkMultiArg reports whether v is a slice of the form []S or []*S for
// some struct type S.
//
// It returns the category of the slice's elements together with the
// reflect.Type of S. For any other value it returns multiArgTypeInvalid and
// a nil type.
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
	if v.Kind() != reflect.Slice {
		return multiArgTypeInvalid, nil
	}

	elem := v.Type().Elem()
	if elem.Kind() == reflect.Struct {
		return multiArgTypeStruct, elem
	}
	if elem.Kind() == reflect.Ptr && elem.Elem().Kind() == reflect.Struct {
		return multiArgTypeStructPtr, elem.Elem()
	}
	return multiArgTypeInvalid, nil
}
451
452func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
453 v, err := oleutil.GetProperty(item, prop)
454 if err != nil {
455 return 0, err
456 }
457 defer v.Clear()
458
459 i := int64(v.Val)
460 return i, nil
461}
462
// CreateQuery returns a WQL query string that queries all columns of src. where
// is an optional string that is appended to the query, to be used with WHERE
// clauses. In such a case, the "WHERE" string should appear at the beginning.
//
// src may be a struct S, a slice []S or []*S, or a pointer to any of those.
// Only the static type of src is inspected, so nil pointers and nil slices
// are accepted. The struct's field names become the selected columns and its
// type name becomes the class in the FROM clause. A single space followed by
// where is always appended, even when where is empty.
//
// An empty string is returned if src is nil or does not describe a struct.
func CreateQuery(src interface{}, where string) string {
	t := reflect.TypeOf(src)
	if t == nil {
		// Untyped nil: there is no type to derive a query from.
		return ""
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() == reflect.Slice {
		t = t.Elem()
		// Accept []*S as well, matching the destination types Query accepts.
		if t.Kind() == reflect.Ptr {
			t = t.Elem()
		}
	}
	if t.Kind() != reflect.Struct {
		return ""
	}

	fields := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		fields = append(fields, t.Field(i).Name)
	}

	var b bytes.Buffer
	b.WriteString("SELECT ")
	b.WriteString(strings.Join(fields, ", "))
	b.WriteString(" FROM ")
	b.WriteString(t.Name())
	b.WriteString(" " + where)
	return b.String()
}