bookdata/cli/cluster_gender/clusters.rs

//! Read cluster information.
use std::collections::HashMap;
use std::convert::identity;
use std::fs::File;
use std::path::Path;

use super::authors::AuthorTable;
use crate::arrow::scan_parquet_file;
use crate::gender::*;
use crate::prelude::*;
use crate::util::logging::item_progress;
use anyhow::{anyhow, Result};
use arrow::array::Int32Array;
use arrow::compute::kernels;
use arrow::datatypes::DataType;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::ProjectionMask;
use parquet_derive::ParquetRecordReader;

/// Record for storing a cluster's gender statistics while aggregating.
#[derive(Debug, Default)]
pub struct ClusterStats {
    /// Number of author names seen across the cluster's books.
    pub n_book_authors: u32,
    /// Number of author records matched in the author table.
    pub n_author_recs: u32,
    /// Bag of genders accumulated from matched author records.
    pub genders: GenderBag,
}

/// Row struct for reading cluster author names.
///
/// Field names must match the Parquet column names for the derived reader.
#[derive(Debug, ParquetRecordReader)]
struct ClusterAuthor {
    cluster: i32,
    author_name: String,
}

/// Map from cluster ID to that cluster's aggregated gender statistics.
pub type ClusterTable = HashMap<i32, ClusterStats>;

/// Read cluster author names and resolve them to gender information.
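///
/// A minimal usage sketch; `load_author_table` and the file name are
/// hypothetical stand-ins for however the author table is built earlier
/// in this pipeline:
///
/// ```ignore
/// // `load_author_table` is illustrative, not a real API in this module.
/// let authors = load_author_table()?;
/// let clusters = read_resolve(Path::new("cluster-authors.parquet"), &authors)?;
/// println!("resolved {} clusters", clusters.len());
/// ```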
pub fn read_resolve(path: &Path, authors: &AuthorTable) -> Result<ClusterTable> {
    let timer = Timer::new();
    info!("reading cluster authors from {}", path.display());
    let iter = scan_parquet_file(path)?;

    let pb = item_progress(iter.remaining() as u64, "authors");

    let mut table = ClusterTable::new();

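    // Fold each (cluster, author name) row into that cluster's running stats.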
    for row in pb.wrap_iter(iter) {
        let row: ClusterAuthor = row?;
        let rec = table.entry(row.cluster).or_default();
        rec.n_book_authors += 1;
        // Names missing from the author table still count as book authors,
        // but contribute no author records or gender information.
        if let Some(info) = authors.get(row.author_name.as_str()) {
            rec.n_author_recs += info.n_author_recs;
            rec.genders.merge_from(&info.genders);
        }
    }

    info!(
        "scanned genders for {} clusters in {}",
        table.len(),
        timer.human_elapsed()
    );

    Ok(table)
}

/// Read the full list of cluster IDs.
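///
/// A small usage sketch (the file name is illustrative, not a real path):
///
/// ```ignore
/// let ids = all_clusters("clusters.parquet")?;
/// assert!(!ids.is_empty());
/// ```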
pub fn all_clusters<P: AsRef<Path>>(path: P) -> Result<Vec<i32>> {
    let path = path.as_ref();
    info!("reading cluster IDs from {}", path.display());
    let file = File::open(path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let meta = builder.metadata().clone();
    let schema = builder.parquet_schema();
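    // Project down to just the `cluster` column so the reader skips the rest.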
    let project = ProjectionMask::columns(schema, ["cluster"]);

    let reader = builder.with_projection(project).build()?;
    // Pre-size the vector from the file metadata's row count.
    let mut ids = Vec::with_capacity(meta.file_metadata().num_rows() as usize);

    let pb = item_progress(meta.file_metadata().num_rows() as u64, "clusters");
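    // Stream record batches, collecting the non-null cluster IDs from each.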
    for rb in reader {
        let rb = rb?;
        assert_eq!(rb.schema().field(0).name(), "cluster");
        let col = rb.column(0);
        // Normalize the column to Int32 regardless of its stored integer width.
        let col = kernels::cast(col, &DataType::Int32)?;
        let col = col
            .as_any()
            .downcast_ref::<Int32Array>()
            .ok_or_else(|| anyhow!("invalid type for cluster field"))?;
        // Skip null entries; only concrete IDs go in the result.
        ids.extend(col.iter().filter_map(identity));
        pb.inc(col.len() as u64);
    }
    pb.finish();

    info!("found {} cluster IDs", ids.len());
    Ok(ids)
}