Serge Bazanski | cc25bdf | 2018-10-25 14:02:58 +0200 | [diff] [blame] | 1 | // +build windows |
| 2 | |
| 3 | /* |
| 4 | Package wmi provides a WQL interface for WMI on Windows. |
| 5 | |
| 6 | Example code to print names of running processes: |
| 7 | |
| 8 | type Win32_Process struct { |
| 9 | Name string |
| 10 | } |
| 11 | |
| 12 | func main() { |
| 13 | var dst []Win32_Process |
| 14 | q := wmi.CreateQuery(&dst, "") |
| 15 | err := wmi.Query(q, &dst) |
| 16 | if err != nil { |
| 17 | log.Fatal(err) |
| 18 | } |
| 19 | for i, v := range dst { |
| 20 | println(i, v.Name) |
| 21 | } |
| 22 | } |
| 23 | |
| 24 | */ |
| 25 | package wmi |
| 26 | |
| 27 | import ( |
| 28 | "bytes" |
| 29 | "errors" |
| 30 | "fmt" |
| 31 | "log" |
| 32 | "os" |
| 33 | "reflect" |
| 34 | "runtime" |
| 35 | "strconv" |
| 36 | "strings" |
| 37 | "sync" |
| 38 | "time" |
| 39 | |
| 40 | "github.com/go-ole/go-ole" |
| 41 | "github.com/go-ole/go-ole/oleutil" |
| 42 | ) |
| 43 | |
// l is a package-level logger that writes to standard output using the
// standard log flags.
var l = log.New(os.Stdout, "", log.LstdFlags)

// Sentinel errors returned by this package.
var (
	// ErrInvalidEntityType is returned by Client.Query when dst is not a
	// non-nil pointer to a slice of structs or of struct pointers.
	ErrInvalidEntityType = errors.New("wmi: invalid entity type")

	// ErrNilCreateObject is the error returned if CreateObject hands back a
	// nil object even though it reported no error.
	ErrNilCreateObject = errors.New("wmi: create object returned nil")
)

// lock is held by Client.Query for the duration of a query, so queries made
// through it run one at a time.
var lock sync.Mutex

// S_FALSE is the HRESULT returned by CoInitializeEx when COM was already
// initialized on the calling thread.
const S_FALSE = 0x1
| 56 | |
| 57 | // QueryNamespace invokes Query with the given namespace on the local machine. |
| 58 | func QueryNamespace(query string, dst interface{}, namespace string) error { |
| 59 | return Query(query, dst, nil, namespace) |
| 60 | } |
| 61 | |
| 62 | // Query runs the WQL query and appends the values to dst. |
| 63 | // |
| 64 | // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in |
| 65 | // the query must have the same name in dst. Supported types are all signed and |
| 66 | // unsigned integers, time.Time, string, bool, or a pointer to one of those. |
| 67 | // Array types are not supported. |
| 68 | // |
| 69 | // By default, the local machine and default namespace are used. These can be |
| 70 | // changed using connectServerArgs. See |
| 71 | // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details. |
| 72 | // |
| 73 | // Query is a wrapper around DefaultClient.Query. |
| 74 | func Query(query string, dst interface{}, connectServerArgs ...interface{}) error { |
| 75 | if DefaultClient.SWbemServicesClient == nil { |
| 76 | return DefaultClient.Query(query, dst, connectServerArgs...) |
| 77 | } |
| 78 | return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...) |
| 79 | } |
| 80 | |
| 81 | // A Client is an WMI query client. |
| 82 | // |
| 83 | // Its zero value (DefaultClient) is a usable client. |
| 84 | type Client struct { |
| 85 | // NonePtrZero specifies if nil values for fields which aren't pointers |
| 86 | // should be returned as the field types zero value. |
| 87 | // |
| 88 | // Setting this to true allows stucts without pointer fields to be used |
| 89 | // without the risk failure should a nil value returned from WMI. |
| 90 | NonePtrZero bool |
| 91 | |
| 92 | // PtrNil specifies if nil values for pointer fields should be returned |
| 93 | // as nil. |
| 94 | // |
| 95 | // Setting this to true will set pointer fields to nil where WMI |
| 96 | // returned nil, otherwise the types zero value will be returned. |
| 97 | PtrNil bool |
| 98 | |
| 99 | // AllowMissingFields specifies that struct fields not present in the |
| 100 | // query result should not result in an error. |
| 101 | // |
| 102 | // Setting this to true allows custom queries to be used with full |
| 103 | // struct definitions instead of having to define multiple structs. |
| 104 | AllowMissingFields bool |
| 105 | |
| 106 | // SWbemServiceClient is an optional SWbemServices object that can be |
| 107 | // initialized and then reused across multiple queries. If it is null |
| 108 | // then the method will initialize a new temporary client each time. |
| 109 | SWbemServicesClient *SWbemServices |
| 110 | } |
| 111 | |
| 112 | // DefaultClient is the default Client and is used by Query, QueryNamespace |
| 113 | var DefaultClient = &Client{} |
| 114 | |
| 115 | // Query runs the WQL query and appends the values to dst. |
| 116 | // |
| 117 | // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in |
| 118 | // the query must have the same name in dst. Supported types are all signed and |
| 119 | // unsigned integers, time.Time, string, bool, or a pointer to one of those. |
| 120 | // Array types are not supported. |
| 121 | // |
| 122 | // By default, the local machine and default namespace are used. These can be |
| 123 | // changed using connectServerArgs. See |
| 124 | // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details. |
| 125 | func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error { |
| 126 | dv := reflect.ValueOf(dst) |
| 127 | if dv.Kind() != reflect.Ptr || dv.IsNil() { |
| 128 | return ErrInvalidEntityType |
| 129 | } |
| 130 | dv = dv.Elem() |
| 131 | mat, elemType := checkMultiArg(dv) |
| 132 | if mat == multiArgTypeInvalid { |
| 133 | return ErrInvalidEntityType |
| 134 | } |
| 135 | |
| 136 | lock.Lock() |
| 137 | defer lock.Unlock() |
| 138 | runtime.LockOSThread() |
| 139 | defer runtime.UnlockOSThread() |
| 140 | |
| 141 | err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED) |
| 142 | if err != nil { |
| 143 | oleCode := err.(*ole.OleError).Code() |
| 144 | if oleCode != ole.S_OK && oleCode != S_FALSE { |
| 145 | return err |
| 146 | } |
| 147 | } |
| 148 | defer ole.CoUninitialize() |
| 149 | |
| 150 | unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator") |
| 151 | if err != nil { |
| 152 | return err |
| 153 | } else if unknown == nil { |
| 154 | return ErrNilCreateObject |
| 155 | } |
| 156 | defer unknown.Release() |
| 157 | |
| 158 | wmi, err := unknown.QueryInterface(ole.IID_IDispatch) |
| 159 | if err != nil { |
| 160 | return err |
| 161 | } |
| 162 | defer wmi.Release() |
| 163 | |
| 164 | // service is a SWbemServices |
| 165 | serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...) |
| 166 | if err != nil { |
| 167 | return err |
| 168 | } |
| 169 | service := serviceRaw.ToIDispatch() |
| 170 | defer serviceRaw.Clear() |
| 171 | |
| 172 | // result is a SWBemObjectSet |
| 173 | resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query) |
| 174 | if err != nil { |
| 175 | return err |
| 176 | } |
| 177 | result := resultRaw.ToIDispatch() |
| 178 | defer resultRaw.Clear() |
| 179 | |
| 180 | count, err := oleInt64(result, "Count") |
| 181 | if err != nil { |
| 182 | return err |
| 183 | } |
| 184 | |
| 185 | enumProperty, err := result.GetProperty("_NewEnum") |
| 186 | if err != nil { |
| 187 | return err |
| 188 | } |
| 189 | defer enumProperty.Clear() |
| 190 | |
| 191 | enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant) |
| 192 | if err != nil { |
| 193 | return err |
| 194 | } |
| 195 | if enum == nil { |
| 196 | return fmt.Errorf("can't get IEnumVARIANT, enum is nil") |
| 197 | } |
| 198 | defer enum.Release() |
| 199 | |
| 200 | // Initialize a slice with Count capacity |
| 201 | dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count))) |
| 202 | |
| 203 | var errFieldMismatch error |
| 204 | for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) { |
| 205 | if err != nil { |
| 206 | return err |
| 207 | } |
| 208 | |
| 209 | err := func() error { |
| 210 | // item is a SWbemObject, but really a Win32_Process |
| 211 | item := itemRaw.ToIDispatch() |
| 212 | defer item.Release() |
| 213 | |
| 214 | ev := reflect.New(elemType) |
| 215 | if err = c.loadEntity(ev.Interface(), item); err != nil { |
| 216 | if _, ok := err.(*ErrFieldMismatch); ok { |
| 217 | // We continue loading entities even in the face of field mismatch errors. |
| 218 | // If we encounter any other error, that other error is returned. Otherwise, |
| 219 | // an ErrFieldMismatch is returned. |
| 220 | errFieldMismatch = err |
| 221 | } else { |
| 222 | return err |
| 223 | } |
| 224 | } |
| 225 | if mat != multiArgTypeStructPtr { |
| 226 | ev = ev.Elem() |
| 227 | } |
| 228 | dv.Set(reflect.Append(dv, ev)) |
| 229 | return nil |
| 230 | }() |
| 231 | if err != nil { |
| 232 | return err |
| 233 | } |
| 234 | } |
| 235 | return errFieldMismatch |
| 236 | } |
| 237 | |
// ErrFieldMismatch is returned when a WMI property cannot be loaded into the
// destination struct: the value has a different type than the field it is to
// be loaded into, or the field is missing or unexported in the destination
// struct.
// StructType is the type of the struct pointed to by the destination argument.
type ErrFieldMismatch struct {
	// StructType is the destination type the value was being loaded into.
	StructType reflect.Type
	// FieldName is the name of the struct field (and WMI property).
	FieldName string
	// Reason is a short human-readable description of the mismatch.
	Reason string
}

// Error renders the mismatch as
// `wmi: cannot load field "<FieldName>" into a "<StructType>": <Reason>`.
func (e *ErrFieldMismatch) Error() string {
	const format = "wmi: cannot load field %q into a %q: %s"
	return fmt.Sprintf(format, e.FieldName, e.StructType, e.Reason)
}
| 252 | |
// timeType is the reflect.Type of time.Time, used to recognise time fields.
var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
| 254 | |
| 255 | // loadEntity loads a SWbemObject into a struct pointer. |
| 256 | func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) { |
| 257 | v := reflect.ValueOf(dst).Elem() |
| 258 | for i := 0; i < v.NumField(); i++ { |
| 259 | f := v.Field(i) |
| 260 | of := f |
| 261 | isPtr := f.Kind() == reflect.Ptr |
| 262 | if isPtr { |
| 263 | ptr := reflect.New(f.Type().Elem()) |
| 264 | f.Set(ptr) |
| 265 | f = f.Elem() |
| 266 | } |
| 267 | n := v.Type().Field(i).Name |
| 268 | if !f.CanSet() { |
| 269 | return &ErrFieldMismatch{ |
| 270 | StructType: of.Type(), |
| 271 | FieldName: n, |
| 272 | Reason: "CanSet() is false", |
| 273 | } |
| 274 | } |
| 275 | prop, err := oleutil.GetProperty(src, n) |
| 276 | if err != nil { |
| 277 | if !c.AllowMissingFields { |
| 278 | errFieldMismatch = &ErrFieldMismatch{ |
| 279 | StructType: of.Type(), |
| 280 | FieldName: n, |
| 281 | Reason: "no such struct field", |
| 282 | } |
| 283 | } |
| 284 | continue |
| 285 | } |
| 286 | defer prop.Clear() |
| 287 | |
| 288 | switch val := prop.Value().(type) { |
| 289 | case int8, int16, int32, int64, int: |
| 290 | v := reflect.ValueOf(val).Int() |
| 291 | switch f.Kind() { |
| 292 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| 293 | f.SetInt(v) |
| 294 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
| 295 | f.SetUint(uint64(v)) |
| 296 | default: |
| 297 | return &ErrFieldMismatch{ |
| 298 | StructType: of.Type(), |
| 299 | FieldName: n, |
| 300 | Reason: "not an integer class", |
| 301 | } |
| 302 | } |
| 303 | case uint8, uint16, uint32, uint64: |
| 304 | v := reflect.ValueOf(val).Uint() |
| 305 | switch f.Kind() { |
| 306 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| 307 | f.SetInt(int64(v)) |
| 308 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
| 309 | f.SetUint(v) |
| 310 | default: |
| 311 | return &ErrFieldMismatch{ |
| 312 | StructType: of.Type(), |
| 313 | FieldName: n, |
| 314 | Reason: "not an integer class", |
| 315 | } |
| 316 | } |
| 317 | case string: |
| 318 | switch f.Kind() { |
| 319 | case reflect.String: |
| 320 | f.SetString(val) |
| 321 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| 322 | iv, err := strconv.ParseInt(val, 10, 64) |
| 323 | if err != nil { |
| 324 | return err |
| 325 | } |
| 326 | f.SetInt(iv) |
| 327 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
| 328 | uv, err := strconv.ParseUint(val, 10, 64) |
| 329 | if err != nil { |
| 330 | return err |
| 331 | } |
| 332 | f.SetUint(uv) |
| 333 | case reflect.Struct: |
| 334 | switch f.Type() { |
| 335 | case timeType: |
| 336 | if len(val) == 25 { |
| 337 | mins, err := strconv.Atoi(val[22:]) |
| 338 | if err != nil { |
| 339 | return err |
| 340 | } |
| 341 | val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60) |
| 342 | } |
| 343 | t, err := time.Parse("20060102150405.000000-0700", val) |
| 344 | if err != nil { |
| 345 | return err |
| 346 | } |
| 347 | f.Set(reflect.ValueOf(t)) |
| 348 | } |
| 349 | } |
| 350 | case bool: |
| 351 | switch f.Kind() { |
| 352 | case reflect.Bool: |
| 353 | f.SetBool(val) |
| 354 | default: |
| 355 | return &ErrFieldMismatch{ |
| 356 | StructType: of.Type(), |
| 357 | FieldName: n, |
| 358 | Reason: "not a bool", |
| 359 | } |
| 360 | } |
| 361 | case float32: |
| 362 | switch f.Kind() { |
| 363 | case reflect.Float32: |
| 364 | f.SetFloat(float64(val)) |
| 365 | default: |
| 366 | return &ErrFieldMismatch{ |
| 367 | StructType: of.Type(), |
| 368 | FieldName: n, |
| 369 | Reason: "not a Float32", |
| 370 | } |
| 371 | } |
| 372 | default: |
| 373 | if f.Kind() == reflect.Slice { |
| 374 | switch f.Type().Elem().Kind() { |
| 375 | case reflect.String: |
| 376 | safeArray := prop.ToArray() |
| 377 | if safeArray != nil { |
| 378 | arr := safeArray.ToValueArray() |
| 379 | fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr)) |
| 380 | for i, v := range arr { |
| 381 | s := fArr.Index(i) |
| 382 | s.SetString(v.(string)) |
| 383 | } |
| 384 | f.Set(fArr) |
| 385 | } |
| 386 | case reflect.Uint8: |
| 387 | safeArray := prop.ToArray() |
| 388 | if safeArray != nil { |
| 389 | arr := safeArray.ToValueArray() |
| 390 | fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr)) |
| 391 | for i, v := range arr { |
| 392 | s := fArr.Index(i) |
| 393 | s.SetUint(reflect.ValueOf(v).Uint()) |
| 394 | } |
| 395 | f.Set(fArr) |
| 396 | } |
| 397 | default: |
| 398 | return &ErrFieldMismatch{ |
| 399 | StructType: of.Type(), |
| 400 | FieldName: n, |
| 401 | Reason: fmt.Sprintf("unsupported slice type (%T)", val), |
| 402 | } |
| 403 | } |
| 404 | } else { |
| 405 | typeof := reflect.TypeOf(val) |
| 406 | if typeof == nil && (isPtr || c.NonePtrZero) { |
| 407 | if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) { |
| 408 | of.Set(reflect.Zero(of.Type())) |
| 409 | } |
| 410 | break |
| 411 | } |
| 412 | return &ErrFieldMismatch{ |
| 413 | StructType: of.Type(), |
| 414 | FieldName: n, |
| 415 | Reason: fmt.Sprintf("unsupported type (%T)", val), |
| 416 | } |
| 417 | } |
| 418 | } |
| 419 | } |
| 420 | return errFieldMismatch |
| 421 | } |
| 422 | |
// multiArgType classifies the element type of the slice passed to Query.
type multiArgType int

const (
	// multiArgTypeInvalid: not a slice of structs or of struct pointers.
	multiArgTypeInvalid multiArgType = iota
	// multiArgTypeStruct: a []S for some struct type S.
	multiArgTypeStruct
	// multiArgTypeStructPtr: a []*S for some struct type S.
	multiArgTypeStructPtr
)

// checkMultiArg reports whether v is a []S or a []*S for some struct type S.
//
// It returns the matching category together with the reflect.Type of S, or
// (multiArgTypeInvalid, nil) when v has any other shape.
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
	if v.Kind() != reflect.Slice {
		return multiArgTypeInvalid, nil
	}
	et := v.Type().Elem()
	if et.Kind() == reflect.Struct {
		return multiArgTypeStruct, et
	}
	if et.Kind() == reflect.Ptr && et.Elem().Kind() == reflect.Struct {
		return multiArgTypeStructPtr, et.Elem()
	}
	return multiArgTypeInvalid, nil
}
| 451 | |
| 452 | func oleInt64(item *ole.IDispatch, prop string) (int64, error) { |
| 453 | v, err := oleutil.GetProperty(item, prop) |
| 454 | if err != nil { |
| 455 | return 0, err |
| 456 | } |
| 457 | defer v.Clear() |
| 458 | |
| 459 | i := int64(v.Val) |
| 460 | return i, nil |
| 461 | } |
| 462 | |
// CreateQuery returns a WQL query string that selects every field of the
// struct type described by src from the WMI class named after that type.
//
// src may be a struct S, a *S, a []S, a []*S, or a pointer to either slice
// type. Only its type is inspected, so nil pointers and nil slices work too.
// where is appended after a single space and is meant for a WHERE clause; in
// that case the "WHERE" keyword should appear at its beginning. An empty
// string is returned when src does not describe a struct type, including
// when src is nil.
func CreateQuery(src interface{}, where string) string {
	t := reflect.TypeOf(src)
	if t == nil {
		return ""
	}
	// Work on the type rather than the value. reflect.Indirect on a nil
	// pointer yields a zero Value whose Type() panics; the type does not.
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() == reflect.Slice {
		t = t.Elem()
		// Query accepts *[]*S, so resolve a []*S element to S as well.
		if t.Kind() == reflect.Ptr {
			t = t.Elem()
		}
	}
	if t.Kind() != reflect.Struct {
		return ""
	}

	fields := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		fields = append(fields, t.Field(i).Name)
	}

	var b bytes.Buffer
	b.WriteString("SELECT ")
	b.WriteString(strings.Join(fields, ", "))
	b.WriteString(" FROM ")
	b.WriteString(t.Name())
	b.WriteString(" " + where)
	return b.String()
}