Buscar..


Filtrado de datos con una matriz booleana

Cuando solo se proporciona un único argumento a numpy where funciona, devuelve los índices de la matriz de entrada (la condition ) que se evalúan como verdaderas (el mismo comportamiento que numpy.nonzero ). Esto se puede usar para extraer los índices de una matriz que satisfacen una condición dada.

import numpy as np

a = np.arange(20).reshape(2,10)
# a = array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
#           [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])

# Generate boolean array indicating which values in a are both greater than 7 and less than 13
condition = np.bitwise_and(a>7, a<13)
# condition = array([[False, False, False, False, False, False, False, False,  True, True],
#                    [True,  True,  True, False, False, False, False, False, False, False]], dtype=bool)

# Get the indices of a where the condition is True
ind = np.where(condition)
# ind = (array([0, 0, 1, 1, 1]), array([8, 9, 0, 1, 2]))

keep = a[ind]
# keep = [ 8  9 10 11 12]

Si no necesita los índices, esto se puede lograr en un solo paso usando extract , donde nuevamente especifique la condition como primer argumento, pero le da a la array que devuelva los valores desde donde la condición es verdadera como segundo argumento.

# np.extract(condition, array)
keep = np.extract(condition, a)
# keep = [ 8  9 10 11 12]

Se pueden proporcionar dos argumentos adicionales, x e y a where , en cuyo caso, la salida contendrá los valores de x donde la condición es True y los valores de y donde la condición es False .

# Set elements of a which are NOT greater than 7 and less than 13 to zero, np.where(condition, x, y)
a = np.where(condition, a, a*0)
print(a)
# Out: array([[ 0,  0,  0,  0,  0,  0,  0,  0,  8,  9],
#            [10, 11, 12,  0,  0,  0,  0,  0,  0,  0]])

Índices de filtrado directo.

Para casos simples, puede filtrar datos directamente.

a = np.random.normal(size=10)
print(a)
#[-1.19423121  1.10481873  0.26332982 -0.53300387 -0.04809928  1.77107775
# 1.16741359  0.17699948 -0.06342169 -1.74213078]
b = a[a>0]
print(b)
#[ 1.10481873  0.26332982  1.77107775  1.16741359  0.17699948]


Modified text is an extract of the original Stack Overflow Documentation
Licenciado bajo CC BY-SA 3.0
No afiliado a Stack Overflow