bookdata/ids/
index.rs

1//! Data structure for mapping string keys to numeric identifiers.
2use std::borrow::Borrow;
3use std::fs::File;
4use std::hash::Hash;
5use std::path::Path;
6use std::sync::Arc;
7
8use arrow::{
9    array::{Int32Array, RecordBatch, StringArray},
10    datatypes::{DataType, Field, Schema},
11};
12use hashbrown::hash_map::{HashMap, Keys};
13use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, ProjectionMask};
14
15use anyhow::{anyhow, Result};
16use log::*;
17use thiserror::Error;
18
19#[cfg(test)]
20use quickcheck::{Arbitrary, Gen};
21#[cfg(test)]
22use tempfile::tempdir;
23
24use crate::arrow::writer::parquet_writer_defaults;
25
/// Numeric identifier assigned to a key in an [`IdIndex`].
///
/// This is a signed 32-bit integer so that identifiers are stored directly in
/// Arrow/Parquet `Int32` columns.
pub type Id = i32;
28
29#[derive(Error, Debug)]
30pub enum IndexError {
31    #[error("key not present in frozen index")]
32    KeyNotPresent,
33}
34
35/// Index identifiers from a data type
36pub struct IdIndex<K> {
37    map: HashMap<K, Id>,
38    frozen: bool,
39}
40
41impl<K> IdIndex<K>
42where
43    K: Eq + Hash,
44{
45    /// Create a new index.
46    pub fn new() -> IdIndex<K> {
47        IdIndex {
48            map: HashMap::new(),
49            frozen: false,
50        }
51    }
52
53    /// Freeze the index so no new items can be added.
54    #[allow(dead_code)]
55    pub fn freeze(self) -> IdIndex<K> {
56        IdIndex {
57            map: self.map,
58            frozen: true,
59        }
60    }
61
62    /// Get the index length
63    pub fn len(&self) -> usize {
64        self.map.len()
65    }
66
67    /// Get the ID for a key, adding it to the index if needed.
68    pub fn intern<Q>(&mut self, key: &Q) -> Result<Id, IndexError>
69    where
70        K: Borrow<Q>,
71        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
72    {
73        let n = self.map.len() as Id;
74        if self.frozen {
75            self.lookup(key).ok_or(IndexError::KeyNotPresent)
76        } else {
77            // use Hashbrown's raw-entry API to minimize cloning
78            let eb = self.map.raw_entry_mut();
79            let e = eb.from_key(key);
80            let (_, v) = e.or_insert_with(|| (key.to_owned(), n + 1));
81            Ok(*v)
82        }
83    }
84
85    /// Get the ID for a key, adding it to the index if needed and transferring ownership.
86    pub fn intern_owned(&mut self, key: K) -> Result<Id, IndexError> {
87        let n = self.map.len() as Id;
88        if self.frozen {
89            self.lookup(&key).ok_or(IndexError::KeyNotPresent)
90        } else {
91            Ok(*self.map.entry(key).or_insert(n + 1))
92        }
93    }
94
95    /// Look up the ID for a key if it is present.
96    #[allow(dead_code)]
97    pub fn lookup<Q>(&self, key: &Q) -> Option<Id>
98    where
99        K: Borrow<Q>,
100        Q: Hash + Eq + ?Sized,
101    {
102        self.map.get(key).map(|i| *i)
103    }
104
105    /// Iterate over keys (see [std::collections::HashMap::keys]).
106    #[allow(dead_code)]
107    pub fn keys(&self) -> Keys<'_, K, Id> {
108        self.map.keys()
109    }
110}
111
112impl IdIndex<String> {
113    /// Get the keys in order.
114    pub fn key_vec(&self) -> Vec<&str> {
115        let mut vec = Vec::with_capacity(self.len());
116        vec.resize(self.len(), None);
117        for (k, n) in self.map.iter() {
118            let i = (n - 1) as usize;
119            assert!(vec[i].is_none());
120            vec[i] = Some(k);
121        }
122
123        let vec = vec.iter().map(|ro| ro.unwrap().as_str()).collect();
124        vec
125    }
126
127    /// Convert this ID index into an Arrow RecordBatch, with columns for ID and key.
128    pub fn record_batch(&self, id_col: &str, key_col: &str) -> Result<RecordBatch> {
129        debug!("preparing data frame for index");
130        let n = self.map.len() as i32;
131        let ids: Vec<_> = (0..n).collect();
132        let ids = Int32Array::from(ids);
133        let keys = StringArray::from(self.key_vec());
134
135        let schema = Schema::new(vec![
136            Field::new(id_col, DataType::Int32, false),
137            Field::new(key_col, DataType::Utf8, false),
138        ]);
139
140        let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids), Arc::new(keys)])?;
141        Ok(rb)
142    }
143
144    /// Load from a Parquet file, with a standard configuration.
145    ///
146    /// This assumes the Parquet file has the following columns:
147    ///
148    /// - `key`, of type `String`, storing the keys
149    /// - `id`, of type `i32`, storing the IDs
150    #[cfg(test)]
151    pub fn load_standard<P: AsRef<Path>>(path: P) -> Result<IdIndex<String>> {
152        IdIndex::load(path, "id", "key")
153    }
154
155    /// Load from a Parquet file.
156    ///
157    /// This loads two columns from a Parquet file.  The ID column is expected to
158    /// have type `UInt32` (or a type projectable to it), and the key column should
159    /// be `Utf8`.
160    pub fn load<P: AsRef<Path>>(path: P, id_col: &str, key_col: &str) -> Result<IdIndex<String>> {
161        let path_str = path.as_ref().to_string_lossy();
162        info!("reading index from file {}", path_str);
163        let file = File::open(path.as_ref())?;
164        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
165        let project = ProjectionMask::columns(builder.parquet_schema(), [id_col, key_col]);
166        let reader = builder.with_projection(project).build()?;
167
168        let mut map = HashMap::new();
169
170        debug!("reading file contents");
171        for batch in reader {
172            let batch = batch?;
173            assert_eq!(batch.schema().field(0).name(), id_col);
174            assert_eq!(batch.schema().field(1).name(), key_col);
175            let ic = batch
176                .column(0)
177                .as_any()
178                .downcast_ref::<Int32Array>()
179                .ok_or_else(|| {
180                    anyhow!(
181                        "invalid id column type {}",
182                        batch.schema().field(0).data_type()
183                    )
184                })?;
185            let kc = batch
186                .column(1)
187                .as_any()
188                .downcast_ref::<StringArray>()
189                .ok_or_else(|| {
190                    anyhow!(
191                        "invalid id column type {}",
192                        batch.schema().field(0).data_type()
193                    )
194                })?;
195            for pair in ic.into_iter().zip(kc.into_iter()) {
196                if let (Some(id), Some(key)) = pair {
197                    map.insert(key.to_string(), id);
198                }
199            }
200        }
201
202        info!("read {} keys from {}", map.len(), path_str);
203
204        Ok(IdIndex { map, frozen: false })
205    }
206
207    /// Save to a Parquet file with the standard configuration.
208    #[cfg(test)]
209    pub fn save_standard<P: AsRef<Path>>(&self, path: P) -> Result<()> {
210        self.save(path, "id", "key")
211    }
212
213    /// Save to a Parquet file with the standard configuration.
214    pub fn save<P: AsRef<Path>>(&self, path: P, id_col: &str, key_col: &str) -> Result<()> {
215        let frame = self.record_batch(id_col, key_col)?;
216
217        let path = path.as_ref();
218        info!("saving index to {:?}", path);
219        let schema = Schema::new(vec![
220            Field::new(id_col, DataType::Int32, false),
221            Field::new(key_col, DataType::Utf8, false),
222        ]);
223        let schema = Arc::new(schema);
224        let file = File::create(path)?;
225        let props = parquet_writer_defaults().build();
226        let mut writer = ArrowWriter::try_new(file, schema, Some(props))?;
227        writer.write(&frame)?;
228        writer.finish()?;
229
230        Ok(())
231    }
232}
233
#[test]
fn test_index_empty() {
    let index: IdIndex<String> = IdIndex::new();
    // a freshly created index holds no keys, so every lookup misses
    assert_eq!(index.len(), 0);
    assert_eq!(index.lookup("bob"), None);
}
240
#[test]
fn test_index_intern_one() {
    let key = "hackem muche";
    let mut index: IdIndex<String> = IdIndex::new();
    assert_eq!(index.lookup(key), None);

    // the first key interned receives ID 1 and is then resolvable
    let id = index.intern(key).expect("intern failure");
    assert_eq!(id, 1);
    assert_eq!(index.lookup(key), Some(1));
}
249
#[test]
fn test_index_intern_two() {
    let mut index: IdIndex<String> = IdIndex::new();
    assert!(index.lookup("hackem muche").is_none());

    // distinct keys receive consecutive IDs in insertion order
    let first = index.intern("hackem muche").expect("intern failure");
    let second = index.intern("readme").expect("intern failure");
    assert_eq!(first, 1);
    assert_eq!(second, 2);
    assert_eq!(index.lookup("hackem muche"), Some(1));
}
260
#[test]
fn test_index_intern_twice() {
    let mut index: IdIndex<String> = IdIndex::new();
    assert!(index.lookup("hackem muche").is_none());

    // re-interning a key returns the same ID and does not grow the index
    for _ in 0..2 {
        let id = index.intern("hackem muche").expect("intern failure");
        assert_eq!(id, 1);
    }
    assert_eq!(index.len(), 1);
}
271
#[test]
fn test_index_intern_twice_owned() {
    let mut index: IdIndex<String> = IdIndex::new();
    assert!(index.lookup("hackem muche").is_none());

    // owned interning behaves like borrowed interning: one key, one stable ID
    let ids: Vec<Id> = (0..2)
        .map(|_| {
            index
                .intern_owned("hackem muche".to_owned())
                .expect("intern failure")
        })
        .collect();
    assert_eq!(ids, vec![1, 1]);
    assert_eq!(index.len(), 1);
}
284
#[cfg(test)]
#[test_log::test]
fn test_index_save() -> Result<()> {
    let mut g = Gen::new(100);
    let mut index: IdIndex<String> = IdIndex::new();
    for _ in 0..10000 {
        let key = String::arbitrary(&mut g);
        let before = index.lookup(&key);
        let id = index.intern(&key).expect("intern failure");
        match before {
            // keys already present keep their ID
            Some(old) => assert_eq!(id, old),
            // new keys get the next 1-based ID, which equals the new length
            None => assert_eq!(id as usize, index.len()),
        }
    }

    let dir = tempdir()?;
    let pq = dir.path().join("index.parquet");
    index.save_standard(&pq).expect("save error");

    // a round trip through Parquet must preserve every key -> ID pair
    let loaded = IdIndex::load_standard(&pq).expect("load error");
    assert_eq!(loaded.len(), index.len());
    for (key, id) in &index.map {
        let found = loaded.lookup(key);
        assert!(found.is_some());
        assert_eq!(found.unwrap(), *id);
    }

    Ok(())
}
315
#[test]
fn test_index_freeze() {
    let mut index: IdIndex<String> = IdIndex::new();
    assert!(index.lookup("hackem muche").is_none());
    assert_eq!(index.intern("hackem muche").expect("intern failure"), 1);

    let mut index = index.freeze();

    // keys present before freezing still resolve, through both interning APIs
    assert_eq!(index.intern("hackem muche").expect("intern failure"), 1);
    assert_eq!(
        index
            .intern_owned("hackem muche".to_owned())
            .expect("intern failure"),
        1
    );

    // unknown keys are rejected with the dedicated error and are never added
    assert!(matches!(
        index.intern("foobie bletch"),
        Err(IndexError::KeyNotPresent)
    ));
    assert!(matches!(
        index.intern_owned("foobie bletch".to_owned()),
        Err(IndexError::KeyNotPresent)
    ));
    assert_eq!(index.len(), 1);
    assert!(index.lookup("foobie bletch").is_none());
}