Zoeken…


Gegevens filteren met een booleaanse array

Wanneer slechts een enkel argument wordt gegeven aan where functie van numpy's, geeft dit de indices van de numpy.nonzero (de condition ) terug die als waar evalueren (hetzelfde gedrag als numpy.nonzero ). Dit kan worden gebruikt om de indices van een array te extraheren die aan een bepaalde voorwaarde voldoen.

import numpy as np

a = np.arange(20).reshape(2,10)
# a = array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
#           [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])

# Generate boolean array indicating which values in a are both greater than 7 and less than 13
condition = np.bitwise_and(a>7, a<13)
# condition = array([[False, False, False, False, False, False, False, False,  True, True],
#                    [True,  True,  True, False, False, False, False, False, False, False]], dtype=bool)

# Get the indices of a where the condition is True
ind = np.where(condition)
# ind = (array([0, 0, 1, 1, 1]), array([8, 9, 0, 1, 2]))

keep = a[ind]
# keep = [ 8  9 10 11 12]

Als u de indices niet nodig hebt, kan dit in één stap worden bereikt met behulp van extract , waarbij u de condition als het eerste argument kunt opgeven, maar de array de waarden geeft waaruit de voorwaarde waar is als het tweede argument.

# np.extract(condition, array)
keep = np.extract(condition, a)
# keep = [ 8  9 10 11 12]

Twee verdere argumenten x en y kunnen worden gegeven aan where , in welk geval de uitvoer de waarden van x waar de voorwaarde True en de waarden van y waar de voorwaarde False .

# Set elements of a which are NOT greater than 7 and less than 13 to zero, np.where(condition, x, y)
a = np.where(condition, a, a*0)
print(a)
# Out: array([[ 0,  0,  0,  0,  0,  0,  0,  0,  8,  9],
#            [10, 11, 12,  0,  0,  0,  0,  0,  0,  0]])

Direct filteren van indices

Voor eenvoudige gevallen kunt u gegevens rechtstreeks filteren.

a = np.random.normal(size=10)
print(a)
#[-1.19423121  1.10481873  0.26332982 -0.53300387 -0.04809928  1.77107775
# 1.16741359  0.17699948 -0.06342169 -1.74213078]
b = a[a>0]
print(b)
#[ 1.10481873  0.26332982  1.77107775  1.16741359  0.17699948]


Modified text is an extract of the original Stack Overflow Documentation
Licentie onder CC BY-SA 3.0
Niet aangesloten bij Stack Overflow