1use std::collections::HashMap;
2use std::marker::PhantomData;
3use std::mem::take;
4use std::path::Path;
5
6use anyhow::{anyhow, Result};
7use log::*;
8use parquet::record::RecordWriter;
9use parquet_derive::ParquetRecordWriter;
10
11use super::{Dedup, Interaction, Key};
12use crate::arrow::*;
13use crate::io::{file_size, ObjectWriter};
14use crate::util::logging::item_progress;
15use crate::util::Timer;
16
17#[derive(ParquetRecordWriter, Debug)]
19pub struct TimestampRatingRecord {
20 pub user: i32,
21 pub item: i32,
22 pub rating: f32,
23 pub last_rating: f32,
24 pub timestamp: i64,
25 pub last_time: i64,
26 pub nratings: i32,
27}
28
29#[derive(ParquetRecordWriter, Debug)]
31pub struct TimelessRatingRecord {
32 pub user: i32,
33 pub item: i32,
34 pub rating: f32,
35 pub nratings: i32,
36}
37
/// Construction of an output record from every rating observed for one (user, item) pair.
pub trait FromRatingSet {
    /// Builds a record for `user` and `item` from `ratings`.
    ///
    /// Each element of `ratings` is a `(rating, timestamp)` tuple. How the tuples are
    /// combined into a single record is up to the implementor.
    fn create(user: i32, item: i32, ratings: Vec<(f32, i64)>) -> Self;
}
42
43impl FromRatingSet for TimestampRatingRecord {
44 fn create(user: i32, item: i32, ratings: Vec<(f32, i64)>) -> Self {
45 let mut vec = ratings;
46 if vec.len() == 1 {
47 let (rating, timestamp) = vec[0];
49 TimestampRatingRecord {
50 user,
51 item,
52 rating,
53 timestamp,
54 last_rating: rating,
55 last_time: timestamp,
56 nratings: 1,
57 }
58 } else {
59 vec.sort_unstable_by_key(|(r, _ts)| (r * 10.0) as i32);
60 let (rating, timestamp) = if vec.len() % 2 == 0 {
61 let mp_up = vec.len() / 2;
62 let (r1, ts1) = vec[mp_up - 1];
64 let (r2, ts2) = vec[mp_up];
65 ((r1 + r2) * 0.5, (ts1 + ts2) / 2)
67 } else {
68 vec[vec.len() / 2]
69 };
70 vec.sort_unstable_by_key(|(_r, ts)| *ts);
71 let (last_rating, last_time) = vec[vec.len() - 1];
72
73 TimestampRatingRecord {
74 user,
75 item,
76 rating,
77 timestamp,
78 last_rating,
79 last_time,
80 nratings: vec.len() as i32,
81 }
82 }
83 }
84}
85
86impl FromRatingSet for TimelessRatingRecord {
87 fn create(user: i32, item: i32, ratings: Vec<(f32, i64)>) -> Self {
88 let mut vec = ratings;
89 if vec.len() == 1 {
90 let (rating, _ts) = vec[0];
92 TimelessRatingRecord {
93 user,
94 item,
95 rating,
96 nratings: 1,
97 }
98 } else {
99 vec.sort_unstable_by_key(|(r, _ts)| (r * 10.0) as i32);
100 let (rating, _ts) = if vec.len() % 2 == 0 {
101 let mp_up = vec.len() / 2;
102 let (r1, ts1) = vec[mp_up - 1];
104 let (r2, ts2) = vec[mp_up];
105 ((r1 + r2) * 0.5, (ts1 + ts2) / 2)
107 } else {
108 vec[vec.len() / 2]
109 };
110
111 TimelessRatingRecord {
112 user,
113 item,
114 rating,
115 nratings: vec.len() as i32,
116 }
117 }
118 }
119}
120
121pub struct RatingDedup<R>
123where
124 R: FromRatingSet,
125 for<'a> &'a [R]: RecordWriter<R>,
126{
127 _phantom: PhantomData<R>,
128 table: HashMap<Key, Vec<(f32, i64)>>,
129}
130
131impl<I: Interaction, R> Dedup<I> for RatingDedup<R>
132where
133 R: FromRatingSet + Send + Sync + 'static,
134 for<'a> &'a [R]: RecordWriter<R>,
135{
136 fn add_interaction(&mut self, act: I) -> Result<()> {
137 let rating = act
138 .get_rating()
139 .ok_or_else(|| anyhow!("rating deduplicator requires ratings"))?;
140 self.record(act.get_user(), act.get_item(), rating, act.get_timestamp());
141 Ok(())
142 }
143
144 fn save(&mut self, path: &Path) -> Result<usize> {
145 self.write_ratings(path)
146 }
147}
148
149impl<R> Default for RatingDedup<R>
150where
151 R: FromRatingSet + Send + Sync + 'static,
152 for<'a> &'a [R]: RecordWriter<R>,
153{
154 fn default() -> RatingDedup<R> {
155 RatingDedup {
156 _phantom: PhantomData,
157 table: HashMap::new(),
158 }
159 }
160}
161
162impl<R> RatingDedup<R>
163where
164 R: FromRatingSet + Send + Sync + 'static,
165 for<'a> &'a [R]: RecordWriter<R>,
166{
167 pub fn record(&mut self, user: i32, item: i32, rating: f32, timestamp: i64) {
169 let k = Key::new(user, item);
170 let vec = self.table.entry(k).or_insert_with(|| Vec::with_capacity(1));
172 vec.push((rating, timestamp));
174 }
175
176 pub fn write_ratings<P: AsRef<Path>>(&mut self, path: P) -> Result<usize> {
178 let path = path.as_ref();
179 info!(
180 "writing {} deduplicated ratings to {}",
181 friendly::scalar(self.table.len()),
182 path.display()
183 );
184 let mut writer = TableWriter::open(path)?;
185
186 let n = self.table.len() as u64;
187 let timer = Timer::new();
188 let pb = item_progress(n, "writing ratings");
189
190 let table = take(&mut self.table);
192 for (k, vec) in pb.wrap_iter(table.into_iter()) {
193 let record = R::create(k.user, k.item, vec);
194 writer.write_object(record)?;
195 }
196
197 let rv = writer.finish()?;
198 pb.finish_and_clear();
199
200 info!(
201 "wrote {} ratings in {}, file is {}",
202 friendly::scalar(n),
203 timer.human_elapsed(),
204 friendly::bytes(file_size(path)?)
205 );
206
207 Ok(rv)
208 }
209}