bookdata/cli/kcore.rs

1use std::fs::File;
2
3use chrono::NaiveDate;
4use clap::Args;
5
6use crate::prelude::*;
7use polars::prelude::*;
8
/// Compute k-cores of interaction records.
//
// Implementation note: with clap's derive API the `///` doc comments on this
// struct and on each field become the command's `--help` text. Internal notes
// therefore live in `//` comments so they do not leak into user-facing output.
#[derive(Debug, Clone, Args)]
#[command(name = "kcore")]
pub struct Kcore {
    /// The size of the k-core.
    // Fallback threshold used when `--user-k` and/or `--item-k` are not given.
    #[arg(short = 'k', long = "k", default_value_t = 5)]
    k: u32,

    /// The user rating count for a (ku,ki)-core
    // Per-user threshold; `exec` resolves it as `uk = user_k.unwrap_or(k)`.
    #[arg(short = 'U', long = "user-k")]
    user_k: Option<u32>,

    /// The item rating count for a (ku,ki)-core
    // Per-item threshold; `exec` resolves it as `ik = item_k.unwrap_or(k)`.
    #[arg(short = 'I', long = "item-k")]
    item_k: Option<u32>,

    /// Limit to ratings in a particular year.
    // Supplies a default window of [Jan 1 of `year`, Jan 1 of `year + 1`).
    // An explicit `--start-date` / `--end-date` takes precedence for its bound.
    #[arg(long = "year")]
    year: Option<i32>,

    /// Limit ratings to after a particular date (inclusive)
    // Converted to the timestamp of 00:00:00 on this date and compared with
    // `>=` against the `first_time` column.
    #[arg(long = "start-date")]
    start: Option<NaiveDate>,

    /// Limit ratings to before a particular date (exclusive)
    // Converted to the timestamp of 00:00:00 on this date and compared with
    // `<` against the `first_time` column.
    #[arg(long = "end-date")]
    end: Option<NaiveDate>,

    /// The output file.
    // The resulting core is written here via `save_df_parquet`.
    #[arg(short = 'o', long = "output", name = "FILE")]
    output: PathBuf,

    /// The input file
    // Read with polars' `ParquetReader`. It must have `user_id` and `item_id`
    // columns, plus `first_time` when a date filter is used.
    #[arg(name = "INPUT")]
    input: PathBuf,
}
45
46impl Command for Kcore {
47    fn exec(&self) -> Result<()> {
48        let uk = self.user_k.unwrap_or(self.k);
49        let ik = self.item_k.unwrap_or(self.k);
50        info!(
51            "computing ({},{})-core for {}",
52            uk,
53            ik,
54            self.input.display()
55        );
56
57        let file = File::open(&self.input)?;
58        let mut actions = ParquetReader::new(file).finish()?;
59        info!("loaded {} actions", friendly::scalar(actions.height()));
60
61        let start = self
62            .start
63            .or_else(|| self.year.map(|y| NaiveDate::from_ymd_opt(y, 1, 1).unwrap()));
64        let end = self.end.or_else(|| {
65            self.year
66                .map(|y| NaiveDate::from_ymd_opt(y + 1, 1, 1).unwrap())
67        });
68
69        if let Some(start) = start {
70            info!("removing actions before {}", start);
71            let start = start.and_hms_opt(0, 0, 0).unwrap().timestamp();
72            // currently hard-coded for goodreads
73            let col = actions.column("first_time")?;
74            let mask = col.gt_eq(start)?;
75            actions = actions.filter(&mask)?;
76            info!("filtered to {} actions", friendly::scalar(actions.height()));
77        }
78        if let Some(end) = end {
79            info!("removing actions after {}", end);
80            let end = end.and_hms_opt(0, 0, 0).unwrap().timestamp();
81            // currently hard-coded for goodreads
82            let col = actions.column("first_time")?;
83            let mask = col.lt(end)?;
84            actions = actions.filter(&mask)?;
85            info!("filtered to {} actions", friendly::scalar(actions.height()));
86        }
87
88        let n_initial = actions.height();
89        let mut n_last = 0;
90        let mut iters = 0;
91        // we proceed iteratively, alternating filtering users and items
92        // stop when a pass has left it unchanged
93        while actions.height() != n_last {
94            n_last = actions.height();
95            info!(
96                "pass {}: checking items of {} actions",
97                iters + 1,
98                friendly::scalar(actions.height())
99            );
100            actions = filter_counts(actions, "item_id", ik)?;
101
102            info!(
103                "pass {}: checking users of {} actions",
104                iters + 1,
105                friendly::scalar(actions.height())
106            );
107            actions = filter_counts(actions, "user_id", ik)?;
108
109            iters += 1;
110        }
111        info!(
112            "finished computing {}-core with {} of {} actions (imin: {}, umin: {})",
113            self.k,
114            friendly::scalar(actions.height()),
115            friendly::scalar(n_initial),
116            // re-compute this in case it changed
117            actions
118                .column("item_id")?
119                .value_counts(true, true)?
120                .column("count")?
121                .min::<u32>()?
122                .unwrap(),
123            actions
124                .column("user_id")?
125                .value_counts(true, true)?
126                .column("count")?
127                .min::<u32>()?
128                .unwrap(),
129        );
130
131        save_df_parquet(actions, &self.output)?;
132
133        Ok(())
134    }
135}
136
137fn filter_counts(actions: DataFrame, column: &'static str, k: u32) -> Result<DataFrame> {
138    let nstart = actions.height();
139    let counts = actions.column(column)?.value_counts(true, true)?;
140    debug!("value count schema: {:?}", counts.schema());
141    let min_count: u32 = counts
142        .column("count")?
143        .min()?
144        .ok_or_else(|| anyhow!("data frame is empty"))?;
145    if min_count < k {
146        info!("filtering {}s (smallest count: {})", column, min_count);
147        let ifilt = counts
148            .lazy()
149            .filter(col("count").gt_eq(lit(k)))
150            .select(&[col(column)]);
151        let afilt = actions.lazy().inner_join(ifilt, column, column);
152        let actions = afilt.collect()?;
153        info!(
154            "now have {} actions (removed {})",
155            friendly::scalar(actions.height()),
156            nstart - actions.height()
157        );
158        Ok(actions)
159    } else {
160        Ok(actions)
161    }
162}