Question

I am trying to memoize a recursive Collatz sequence function in Rust. However, I need the HashMap of memoized values to keep its contents across separate function calls. Is there an elegant way to do this in Rust, or do I have to declare the HashMap in main and pass it to the function each time? I believe the HashMap is being redeclared as an empty map each time I call the function. Here is my code:

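use std::collections::HashMap;

fn collatz(n: u64) -> u64 {
    // `map` is a local variable, so a fresh, empty map is created on every call.
    let mut map = HashMap::<u64, u64>::new();
    if map.contains_key(&n) { return *map.get(&n).unwrap(); }
    if n == 1 { return 0; }
    map.insert(n,
        match n % 2 {
            0 => { 1 + collatz(n / 2) }
            _ => { 1 + collatz(n * 3 + 1) }
        }
    );
    return *map.get(&n).unwrap();
}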

On a side note, why do I need to add all of the &'s and *'s when I am inserting and pulling items out of the HashMap? I just did it because the compiler was complaining and adding them fixed it but I'm not sure why. Can't I just pass by value? Thanks.


Solution

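You can use the thread_local! macro to declare thread-local statics. Because with() only hands the closure a shared reference, the map has to be wrapped in a RefCell so that it can be mutated: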

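use std::cell::RefCell;
use std::collections::HashMap;

// RefCell lets the closure mutate the map through a shared reference.
thread_local!(static COLLATZ_MEM: RefCell<HashMap<i32, i32>> = RefCell::new(HashMap::new()));

fn collatz(n: i32) -> i32 {
    COLLATZ_MEM.with(|collatz_mem| {
        // Read with collatz_mem.borrow(), update with collatz_mem.borrow_mut().
        0 // Your code here.
    })
}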
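To make "thread-local" concrete, here is a minimal sketch (the names CACHE, remember, cache_len and compare_threads are hypothetical, chosen only for illustration). It shows that entries stored in a thread_local! cache persist across calls on the same thread, while a newly spawned thread starts with its own, empty copy of the map.

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::thread;

thread_local! {
    // Each thread gets its own, independently initialized copy of this map.
    static CACHE: RefCell<HashMap<u64, u64>> = RefCell::new(HashMap::new());
}

/// Records `n -> steps` in the calling thread's cache.
fn remember(n: u64, steps: u64) {
    CACHE.with(|cache| {
        cache.borrow_mut().insert(n, steps);
    });
}

/// Returns how many entries the calling thread's cache holds.
fn cache_len() -> usize {
    CACHE.with(|cache| cache.borrow().len())
}

/// Fills the current thread's cache, then returns the cache size seen by
/// the current thread and by a freshly spawned thread.
fn compare_threads() -> (usize, usize) {
    remember(6, 8); // 6 takes 8 Collatz steps to reach 1
    remember(3, 7); // 3 takes 7 steps
    let here = cache_len();
    let other = thread::spawn(cache_len).join().unwrap();
    (here, other)
}
```

Calling compare_threads() from a thread whose cache starts empty yields (2, 0): the spawned thread never sees the entries written by its parent. For a single-threaded program this is exactly a persistent cache; in a multi-threaded one, each thread memoizes separately.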

P.S. There's also the excellent lazy_static crate, whose lazy_static! macro can be used for truly global static caches.

OTHER TIPS

Rust has no safe equivalent of C's mutable "static" local variables. Instead, you could make a struct, store the HashMap in it, and make collatz a method on it.
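A minimal sketch of that struct-based approach, assuming a hypothetical Collatz type with a steps method (neither name comes from the original answer). The cache lives as long as the struct does, so no global state is needed:

```rust
use std::collections::HashMap;

/// Hypothetical memoizer: owns the cache so it survives between calls.
struct Collatz {
    cache: HashMap<u64, u64>,
}

impl Collatz {
    fn new() -> Self {
        let mut cache = HashMap::new();
        cache.insert(1, 0); // base case: 1 needs zero steps
        Collatz { cache }
    }

    /// Number of steps for `n` to reach 1 (assumes n >= 1).
    fn steps(&mut self, n: u64) -> u64 {
        // `Some(&v)` copies the u64 out of the reference returned by get.
        if let Some(&v) = self.cache.get(&n) {
            return v;
        }
        let next = if n % 2 == 0 { n / 2 } else { 3 * n + 1 };
        let v = 1 + self.steps(next);
        self.cache.insert(n, v);
        v
    }
}
```

Usage: create one value with let mut c = Collatz::new(); and call c.steps(n) as often as needed. Every intermediate value visited along the way is cached, so later calls reuse earlier work. The trade-off versus a static is that you have to keep the Collatz value around (or pass &mut Collatz to whoever needs it), but in return the cache's lifetime is explicit.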

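You can't pass by value because that does either a copy (which might be expensive for complex keys) or a move (which would make you unable to use the key again). Lookups such as contains_key and get therefore only borrow the key (&n), while insert takes the key and value by value because the map has to own them. Likewise, get hands back a reference to the value stored inside the map rather than a copy, which is why you need * to copy it out. In this case your keys are just integers, but the API is meant to work for arbitrary types.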
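The sketch below (the function name lookup_demo is hypothetical) illustrates those borrowing rules with String keys, where copying really would cost something. Note that in current Rust, HashMap::get returns Option<&V>, so you also have to handle the None case (the question's code dates from before Rust 1.0, when the signature was different):

```rust
use std::collections::HashMap;

/// Shows why lookups take `&K` and return references.
fn lookup_demo() -> (Option<u64>, u64, usize) {
    let mut map: HashMap<String, u64> = HashMap::new();

    // insert takes key and value by value: the map now owns the String.
    let key = String::from("six");
    map.insert(key, 8);

    // get only borrows the key (a &str works for String keys) and returns
    // Option<&u64>: a reference into the map, not a copy of the value.
    let found: Option<&u64> = map.get("six");
    // .copied() (or `*` on the inner reference) produces an owned u64,
    // which is cheap because u64 is Copy.
    let found_owned: Option<u64> = found.copied();

    // The entry API inserts-or-updates with a single lookup; `*` writes
    // through the &mut u64 it returns.
    *map.entry(String::from("three")).or_insert(0) += 7;
    // Indexing also takes a borrowed key and panics if it is missing.
    let three = map["three"];

    (found_owned, three, map.len())
}
```

lookup_demo() returns (Some(8), 7, 2). For memoization specifically, map.get(&n).copied() is a compact way to get an Option<u64> out of the map without holding a borrow.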

Using thread_local! (as suggested in another answer) to solve this is not entirely straightforward, because the map must be wrapped in a RefCell before it can be mutated, so here is a full solution:

use std::cell::RefCell;
use std::collections::HashMap;

fn main() {
    println!("thread-local demo for Collatz:");
    (2..20).for_each(|n| println!("{n}: {c}", n = n, c = collatz(n)));
}

thread_local! (static COLLATZ_CACHE: RefCell<HashMap<usize, usize>> = {
    let mut cache = HashMap::new();
    cache.insert(1, 0);
    RefCell::new(cache)
});

fn collatz(n: usize) -> usize {
    COLLATZ_CACHE.with(|cache| {
        let entry = cache.borrow().get(&n).copied();
        if let Some(v) = entry { v } else {
            let v = match n % 2 {
                0 => 1 + collatz(n / 2),
                1 => 1 + collatz(n * 3 + 1),
                _ => unreachable!(),
            };
            cache.borrow_mut().insert(n, v);
            v
        }
    })
}

If thread-local storage is not global enough, you can initialize a static variable lazily. The once_cell crate provides this, and the same functionality is now part of the standard library as std::sync::LazyLock (stable since Rust 1.80), so no nightly feature flag is needed:

use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};

fn main() {
    println!("LazyLock demo for Collatz:");
    (2..20).for_each(|n| println!("{n}: {c}", n = n, c = collatz(n)));
}

// Initialized on first access; the Mutex makes the map safe to share between threads.
static COLLATZ_CACHE: LazyLock<Mutex<HashMap<usize, usize>>> = LazyLock::new(|| {
    let mut cache = HashMap::new();
    cache.insert(1, 0);
    Mutex::new(cache)
});

fn collatz(n: usize) -> usize {
    let cache = &COLLATZ_CACHE;
    let entry = cache.lock().unwrap().get(&n).copied();
    if let Some(v) = entry {
        v
    } else {
        let v = match n % 2 {
            0 => 1 + collatz(n / 2),
            1 => 1 + collatz(n * 3 + 1),
            _ => unreachable!(),
        };
        cache.lock().unwrap().insert(n, v);
        v
    }
}
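To contrast this with the thread-local version, here is a minimal sketch, assuming Rust 1.80 or later for LazyLock (the names CACHE and shared_across_threads are hypothetical). A value computed on one thread is visible in the cache from another. It also shows why the lock must not be held across the recursive call: std's Mutex is not reentrant, so locking it again from the same thread would deadlock or panic.

```rust
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
use std::thread;

// One cache for the whole process, shared by every thread.
static CACHE: LazyLock<Mutex<HashMap<u64, u64>>> = LazyLock::new(|| {
    let mut cache = HashMap::new();
    cache.insert(1, 0); // base case
    Mutex::new(cache)
});

fn collatz(n: u64) -> u64 {
    // The MutexGuard is a temporary dropped at the end of this statement,
    // so the lock is released before we recurse.
    let cached = CACHE.lock().unwrap().get(&n).copied();
    if let Some(v) = cached {
        return v;
    }
    let next = if n % 2 == 0 { n / 2 } else { 3 * n + 1 };
    let v = 1 + collatz(next);
    CACHE.lock().unwrap().insert(n, v);
    v
}

/// Computes collatz(27) on a spawned thread, then checks from the
/// calling thread whether that result is now in the shared cache.
fn shared_across_threads() -> (u64, bool) {
    let steps = thread::spawn(|| collatz(27)).join().unwrap();
    let cached = CACHE.lock().unwrap().contains_key(&27);
    (steps, cached)
}
```

shared_across_threads() returns (111, true). Because the lock is released between the lookup and the insert, two threads may occasionally compute the same n concurrently; both insert the same value, so this only costs some duplicated work.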


Licensed under: CC-BY-SA with attribution