1use std::borrow::Borrow;
3use std::fs::File;
4use std::hash::Hash;
5use std::path::Path;
6use std::sync::Arc;
7
8use arrow::{
9 array::{Int32Array, RecordBatch, StringArray},
10 datatypes::{DataType, Field, Schema},
11};
12use hashbrown::hash_map::{HashMap, Keys};
13use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, ProjectionMask};
14
15use anyhow::{anyhow, Result};
16use log::*;
17use thiserror::Error;
18
19#[cfg(test)]
20use quickcheck::{Arbitrary, Gen};
21#[cfg(test)]
22use tempfile::tempdir;
23
24use crate::arrow::writer::parquet_writer_defaults;
25
/// Integer identifier assigned to each key stored in an [`IdIndex`].
pub type Id = i32;
28
/// Errors produced when interning keys into an [`IdIndex`].
#[derive(Error, Debug)]
pub enum IndexError {
    /// The index is frozen and the requested key was not already present,
    /// so no new id could be assigned.
    #[error("key not present in frozen index")]
    KeyNotPresent,
}
34
35pub struct IdIndex<K> {
37 map: HashMap<K, Id>,
38 frozen: bool,
39}
40
41impl<K> IdIndex<K>
42where
43 K: Eq + Hash,
44{
45 pub fn new() -> IdIndex<K> {
47 IdIndex {
48 map: HashMap::new(),
49 frozen: false,
50 }
51 }
52
53 #[allow(dead_code)]
55 pub fn freeze(self) -> IdIndex<K> {
56 IdIndex {
57 map: self.map,
58 frozen: true,
59 }
60 }
61
62 pub fn len(&self) -> usize {
64 self.map.len()
65 }
66
67 pub fn intern<Q>(&mut self, key: &Q) -> Result<Id, IndexError>
69 where
70 K: Borrow<Q>,
71 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
72 {
73 let n = self.map.len() as Id;
74 if self.frozen {
75 self.lookup(key).ok_or(IndexError::KeyNotPresent)
76 } else {
77 let eb = self.map.raw_entry_mut();
79 let e = eb.from_key(key);
80 let (_, v) = e.or_insert_with(|| (key.to_owned(), n + 1));
81 Ok(*v)
82 }
83 }
84
85 pub fn intern_owned(&mut self, key: K) -> Result<Id, IndexError> {
87 let n = self.map.len() as Id;
88 if self.frozen {
89 self.lookup(&key).ok_or(IndexError::KeyNotPresent)
90 } else {
91 Ok(*self.map.entry(key).or_insert(n + 1))
92 }
93 }
94
95 #[allow(dead_code)]
97 pub fn lookup<Q>(&self, key: &Q) -> Option<Id>
98 where
99 K: Borrow<Q>,
100 Q: Hash + Eq + ?Sized,
101 {
102 self.map.get(key).map(|i| *i)
103 }
104
105 #[allow(dead_code)]
107 pub fn keys(&self) -> Keys<'_, K, Id> {
108 self.map.keys()
109 }
110}
111
112impl IdIndex<String> {
113 pub fn key_vec(&self) -> Vec<&str> {
115 let mut vec = Vec::with_capacity(self.len());
116 vec.resize(self.len(), None);
117 for (k, n) in self.map.iter() {
118 let i = (n - 1) as usize;
119 assert!(vec[i].is_none());
120 vec[i] = Some(k);
121 }
122
123 let vec = vec.iter().map(|ro| ro.unwrap().as_str()).collect();
124 vec
125 }
126
127 pub fn record_batch(&self, id_col: &str, key_col: &str) -> Result<RecordBatch> {
129 debug!("preparing data frame for index");
130 let n = self.map.len() as i32;
131 let ids: Vec<_> = (0..n).collect();
132 let ids = Int32Array::from(ids);
133 let keys = StringArray::from(self.key_vec());
134
135 let schema = Schema::new(vec![
136 Field::new(id_col, DataType::Int32, false),
137 Field::new(key_col, DataType::Utf8, false),
138 ]);
139
140 let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids), Arc::new(keys)])?;
141 Ok(rb)
142 }
143
144 #[cfg(test)]
151 pub fn load_standard<P: AsRef<Path>>(path: P) -> Result<IdIndex<String>> {
152 IdIndex::load(path, "id", "key")
153 }
154
155 pub fn load<P: AsRef<Path>>(path: P, id_col: &str, key_col: &str) -> Result<IdIndex<String>> {
161 let path_str = path.as_ref().to_string_lossy();
162 info!("reading index from file {}", path_str);
163 let file = File::open(path.as_ref())?;
164 let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
165 let project = ProjectionMask::columns(builder.parquet_schema(), [id_col, key_col]);
166 let reader = builder.with_projection(project).build()?;
167
168 let mut map = HashMap::new();
169
170 debug!("reading file contents");
171 for batch in reader {
172 let batch = batch?;
173 assert_eq!(batch.schema().field(0).name(), id_col);
174 assert_eq!(batch.schema().field(1).name(), key_col);
175 let ic = batch
176 .column(0)
177 .as_any()
178 .downcast_ref::<Int32Array>()
179 .ok_or_else(|| {
180 anyhow!(
181 "invalid id column type {}",
182 batch.schema().field(0).data_type()
183 )
184 })?;
185 let kc = batch
186 .column(1)
187 .as_any()
188 .downcast_ref::<StringArray>()
189 .ok_or_else(|| {
190 anyhow!(
191 "invalid id column type {}",
192 batch.schema().field(0).data_type()
193 )
194 })?;
195 for pair in ic.into_iter().zip(kc.into_iter()) {
196 if let (Some(id), Some(key)) = pair {
197 map.insert(key.to_string(), id);
198 }
199 }
200 }
201
202 info!("read {} keys from {}", map.len(), path_str);
203
204 Ok(IdIndex { map, frozen: false })
205 }
206
207 #[cfg(test)]
209 pub fn save_standard<P: AsRef<Path>>(&self, path: P) -> Result<()> {
210 self.save(path, "id", "key")
211 }
212
213 pub fn save<P: AsRef<Path>>(&self, path: P, id_col: &str, key_col: &str) -> Result<()> {
215 let frame = self.record_batch(id_col, key_col)?;
216
217 let path = path.as_ref();
218 info!("saving index to {:?}", path);
219 let schema = Schema::new(vec![
220 Field::new(id_col, DataType::Int32, false),
221 Field::new(key_col, DataType::Utf8, false),
222 ]);
223 let schema = Arc::new(schema);
224 let file = File::create(path)?;
225 let props = parquet_writer_defaults().build();
226 let mut writer = ArrowWriter::try_new(file, schema, Some(props))?;
227 writer.write(&frame)?;
228 writer.finish()?;
229
230 Ok(())
231 }
232}
233
/// A freshly constructed index holds no keys and resolves nothing.
#[test]
fn test_index_empty() {
    let index = IdIndex::<String>::new();
    assert!(index.lookup("bob").is_none());
    assert_eq!(index.len(), 0);
}
240
/// The first interned key receives id 1 and is then resolvable by lookup.
#[test]
fn test_index_intern_one() {
    const KEY: &str = "hackem muche";
    let mut index = IdIndex::<String>::new();
    assert_eq!(index.lookup(KEY), None);

    let assigned = index.intern(KEY).expect("intern failure");
    assert_eq!(assigned, 1);
    assert_eq!(index.lookup(KEY), Some(1));
}
249
/// Distinct keys receive consecutive ids starting at 1.
#[test]
fn test_index_intern_two() {
    let mut index = IdIndex::<String>::new();
    assert!(index.lookup("hackem muche").is_none());

    for (expected, key) in (1..).zip(["hackem muche", "readme"]) {
        let assigned = index.intern(key).expect("intern failure");
        assert_eq!(assigned, expected);
    }

    assert_eq!(index.lookup("hackem muche"), Some(1));
}
260
/// Re-interning an existing borrowed key returns its original id and adds no entry.
#[test]
fn test_index_intern_twice() {
    let mut index = IdIndex::<String>::new();
    assert!(index.lookup("hackem muche").is_none());

    let assigned: Vec<Id> = (0..2)
        .map(|_| index.intern("hackem muche").expect("intern failure"))
        .collect();
    assert_eq!(assigned, vec![1, 1]);
    assert_eq!(index.len(), 1);
}
271
/// Re-interning an existing owned key returns its original id and adds no entry.
#[test]
fn test_index_intern_twice_owned() {
    let key = "hackem muche";
    let mut index = IdIndex::<String>::new();
    assert!(index.lookup(key).is_none());

    for _attempt in 0..2 {
        let outcome = index.intern_owned(key.to_owned());
        assert!(outcome.is_ok());
        assert_eq!(outcome.expect("intern failure"), 1);
    }

    assert_eq!(index.len(), 1);
}
284
/// Intern many random keys, save the index to Parquet, reload it, and check
/// that every key maps to the same id.
#[cfg(test)]
#[test_log::test]
fn test_index_save() -> Result<()> {
    let mut rng = Gen::new(100);
    let mut index: IdIndex<String> = IdIndex::new();
    for _ in 0..10000 {
        let key = String::arbitrary(&mut rng);
        let before = index.lookup(&key);
        let assigned = index.intern(&key).expect("intern failure");
        match before {
            // A repeated key must keep its id.
            Some(existing) => assert_eq!(assigned, existing),
            // A new key gets the next id, which equals the new size.
            None => assert_eq!(assigned as usize, index.len()),
        }
    }

    let dir = tempdir()?;
    let file_path = dir.path().join("index.parquet");
    index.save_standard(&file_path).expect("save error");

    let reloaded = IdIndex::load_standard(&file_path).expect("load error");
    assert_eq!(reloaded.len(), index.len());
    for (key, &id) in index.map.iter() {
        assert_eq!(reloaded.lookup(key), Some(id));
    }

    Ok(())
}
315
/// Freezing keeps existing keys resolvable but rejects unknown keys.
#[test]
fn test_index_freeze() {
    let mut index: IdIndex<String> = IdIndex::new();
    assert!(index.lookup("hackem muche").is_none());

    let first = index.intern("hackem muche");
    assert!(first.is_ok());
    assert_eq!(first.expect("intern failure"), 1);

    let mut frozen = index.freeze();

    let again = frozen.intern("hackem muche");
    assert!(again.is_ok());
    assert_eq!(again.expect("intern failure"), 1);

    assert!(frozen.intern("foobie bletch").is_err());
}