bookdata/cli/cluster_gender/clusters.rs

use std::collections::HashMap;
use std::convert::identity;
use std::fs::File;
use std::path::Path;

use super::authors::AuthorTable;
use crate::arrow::scan_parquet_file;
use crate::gender::*;
use crate::prelude::*;
use crate::util::logging::item_progress;
use anyhow::Result;
use arrow::array::Int32Array;
use arrow::compute::kernels;
use arrow::datatypes::DataType;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::ProjectionMask;
use parquet_derive::ParquetRecordReader;

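/// Gender statistics accumulated for a single book cluster.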
#[derive(Debug, Default)]
pub struct ClusterStats {
    /// Number of book-author links seen for this cluster.
    pub n_book_authors: u32,
    /// Number of author records matched for this cluster's authors.
    pub n_author_recs: u32,
    /// Gender assignments merged from the matched author records.
    pub genders: GenderBag,
}

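/// A cluster-author row from the Parquet file, linking a cluster ID to an author name.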
#[derive(Debug, ParquetRecordReader)]
struct ClusterAuthor {
    cluster: i32,
    author_name: String,
}

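/// Table mapping cluster IDs to their accumulated statistics.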
pub type ClusterTable = HashMap<i32, ClusterStats>;

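/// Read cluster-author records from `path` and resolve them against the author
/// table, accumulating per-cluster link counts, author-record counts, and genders.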
pub fn read_resolve(path: &Path, authors: &AuthorTable) -> Result<ClusterTable> {
    let timer = Timer::new();
    info!("reading cluster authors from {}", path.display());
    let iter = scan_parquet_file(path)?;

    let pb = item_progress(iter.remaining() as u64, "authors");

    let mut table = ClusterTable::new();

    // Accumulate one entry per cluster: count book-author links, and fold in
    // record counts and genders for author names found in the author table.
    for row in pb.wrap_iter(iter) {
        let row: ClusterAuthor = row?;
        let rec = table.entry(row.cluster).or_default();
        rec.n_book_authors += 1;
        if let Some(info) = authors.get(row.author_name.as_str()) {
            rec.n_author_recs += info.n_author_recs;
            rec.genders.merge_from(&info.genders);
        }
    }

    info!(
        "scanned genders for {} clusters in {}",
        table.len(),
        timer.human_elapsed()
    );

    Ok(table)
}

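/// Read every non-null cluster ID from the `cluster` column of a Parquet file.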
pub fn all_clusters<P: AsRef<Path>>(path: P) -> Result<Vec<i32>> {
    info!("reading cluster IDs from {}", path.as_ref().display());
    let path = path.as_ref();
    let file = File::open(&path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let meta = builder.metadata().clone();
    let schema = builder.parquet_schema();
    // Project only the `cluster` column so the rest of the file is not read.
    let project = ProjectionMask::columns(schema, ["cluster"]);

    let reader = builder.with_projection(project).build()?;
    let mut ids = Vec::with_capacity(meta.file_metadata().num_rows() as usize);

    let pb = item_progress(meta.file_metadata().num_rows() as usize, "clusters");
    for rb in reader {
        let rb = rb?;
        assert_eq!(rb.schema().field(0).name(), "cluster");
        let col = rb.column(0);
        // Cast to Int32 and keep the non-null cluster IDs.
        let col = kernels::cast(col, &DataType::Int32)?;
        let col = col
            .as_any()
            .downcast_ref::<Int32Array>()
            .ok_or_else(|| anyhow!("invalid type for cluster field"))?;
        ids.extend(col.iter().filter_map(identity));
        pb.inc(col.len() as u64);
    }
    pb.finish();

    info!("found {} cluster IDs", ids.len());
    Ok(ids)
}