# Get the top rows or random rows within groups of a pandas DataFrame
# Import libraries
import pandas as pd
import seaborn as sns
# Build a small, reproducible demo DataFrame: draw 20 iris rows with a fixed
# seed, order them by species, renumber the index 0..19, and keep only the
# two columns used below.
df = (
    sns.load_dataset('iris')
    .sample(n=20, random_state=20)
    .sort_values(by='species')
    .reset_index(drop=True)
    .loc[:, ['species', 'sepal_width']]
)
df
|    | species    | sepal_width |
|----|------------|-------------|
| 0  | setosa     | 3.2         |
| 1  | setosa     | 3.0         |
| 2  | setosa     | 3.4         |
| 3  | setosa     | 3.2         |
| 4  | setosa     | 3.7         |
| 5  | setosa     | 3.0         |
| 6  | versicolor | 2.7         |
| 7  | versicolor | 2.8         |
| 8  | versicolor | 2.9         |
| 9  | versicolor | 2.8         |
| 10 | versicolor | 2.9         |
| 11 | versicolor | 2.5         |
| 12 | versicolor | 2.4         |
| 13 | virginica  | 3.2         |
| 14 | virginica  | 2.5         |
| 15 | virginica  | 2.8         |
| 16 | virginica  | 3.0         |
| 17 | virginica  | 2.2         |
| 18 | virginica  | 3.0         |
| 19 | virginica  | 2.8         |
## Get the top N rows of each group
One option is to sort the values first, then call `groupby()` followed by `head(n)`.
# Top 3 rows per species by sepal width: order by species (A→Z) and, within a
# species, by sepal width (largest first), then keep the first 3 rows of each
# group. Ties keep their original row order because the multi-column sort is
# stable.
(
    df.sort_values(by=['species', 'sepal_width'], ascending=[True, False])
    .groupby(by='species')
    .head(n=3)
)
|    | species    | sepal_width |
|----|------------|-------------|
| 4  | setosa     | 3.7         |
| 2  | setosa     | 3.4         |
| 0  | setosa     | 3.2         |
| 8  | versicolor | 2.9         |
| 10 | versicolor | 2.9         |
| 7  | versicolor | 2.8         |
| 13 | virginica  | 3.2         |
| 16 | virginica  | 3.0         |
| 18 | virginica  | 3.0         |
## Get random rows within each group
To retrieve n random rows from each group, call `sample(n)` after a `groupby()`:
# Get 1 random row within each group of the DataFrame.
# DataFrameGroupBy.sample (pandas >= 1.1) samples each group directly. This
# avoids groupby().apply(lambda x: x.sample(1)), which operates on the
# grouping column (deprecated since pandas 2.2) and adds a redundant
# 'species' index level on top of the 'species' column. The result keeps
# the original flat row labels.
# random_state makes the draw reproducible, matching the seeded sample taken
# when df was built; remove it to get a different draw on every run.
(
    df
    .groupby('species')
    .sample(n=1, random_state=20)
)
| species    |    | species    | sepal_width |
|------------|----|------------|-------------|
| setosa     | 1  | setosa     | 3.0         |
| versicolor | 10 | versicolor | 2.9         |
| virginica  | 13 | virginica  | 3.2         |