bookdata/cli/cluster/author_gender/
clusters.rs

1//! Read cluster information.
2use std::collections::HashMap;
3use std::convert::identity;
4use std::path::Path;
5
6use super::authors::AuthorTable;
7use crate::arrow::scan_parquet_file;
8use crate::gender::*;
9use crate::prelude::*;
10use crate::util::logging::item_progress;
11use anyhow::Result;
12use parquet_derive::ParquetRecordReader;
13use polars::prelude::*;
14
15/// Record for storing a cluster's gender statistics while aggregating.
16#[derive(Debug, Default)]
17pub struct ClusterStats {
18    pub n_book_authors: u32,
19    pub n_author_recs: u32,
20    pub genders: GenderBag,
21}
22
23/// Row struct for reading cluster author names.
24#[derive(Debug, ParquetRecordReader)]
25struct ClusterAuthor {
26    cluster: i32,
27    author_name: String,
28}
29
30pub type ClusterTable = HashMap<i32, ClusterStats>;
31
32/// Read cluster author names and resolve them to gender information.
33pub fn read_resolve(path: &Path, authors: &AuthorTable) -> Result<ClusterTable> {
34    let timer = Timer::new();
35    info!("reading cluster authors from {}", path.display());
36    let iter = scan_parquet_file(path)?;
37
38    let pb = item_progress(iter.remaining() as u64, "authors");
39
40    let mut table = ClusterTable::new();
41
42    for row in pb.wrap_iter(iter) {
43        let row: ClusterAuthor = row?;
44        let rec = table.entry(row.cluster).or_default();
45        rec.n_book_authors += 1;
46        if let Some(info) = authors.get(row.author_name.as_str()) {
47            rec.n_author_recs += info.n_author_recs;
48            rec.genders.merge_from(&info.genders);
49        }
50    }
51
52    info!(
53        "scanned genders for {} clusters in {}",
54        table.len(),
55        timer.human_elapsed()
56    );
57
58    Ok(table)
59}
60
61/// Read the full list of cluster IDs.
62pub fn all_clusters<P: AsRef<Path>>(path: P) -> Result<Vec<i32>> {
63    info!("reading cluster IDs from {}", path.as_ref().display());
64    let path = path
65        .as_ref()
66        .to_str()
67        .map(|s| s.to_string())
68        .ok_or(anyhow!("invalid unicode path"))?;
69    let df = LazyFrame::scan_parquet(path, Default::default())?;
70    let df = df.select([col("cluster")]);
71    let clusters = df.collect()?;
72    let ids = clusters.column("cluster")?.i32()?;
73
74    info!("found {} cluster IDs", ids.len());
75
76    Ok(ids.into_iter().filter_map(identity).collect())
77}