How do I share a mutable object between threads using Arc?
Arc's documentation says:
Shared references in Rust disallow mutation by default, and Arc is no exception: you cannot generally obtain a mutable reference to something inside an Arc. If you need to mutate through an Arc, use Mutex, RwLock, or one of the Atomic types.
You will likely want a Mutex combined with an Arc:
    use std::{
        sync::{Arc, Mutex},
        thread,
    };

    struct Stats;

    impl Stats {
        fn add_stats(&mut self, _other: &Stats) {}
    }

    fn main() {
        let shared_stats = Arc::new(Mutex::new(Stats));

        let threads = 5;
        for _ in 0..threads {
            let my_stats = shared_stats.clone();
            thread::spawn(move || {
                let mut shared = my_stats.lock().unwrap();
                shared.add_stats(&Stats);
            });
            // Note: Immediately joining, no multithreading happening!
            // THIS WAS A LIE, see below
        }
    }
This is largely cribbed from the Mutex documentation.
How can I use shared_stats after the for? (I'm talking about the Stats object). It seems that the shared_stats cannot be easily converted to Stats.
As of Rust 1.15, it's possible to get the value back. See my additional answer for another solution as well.
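A minimal sketch of one way to get the value back once every thread has finished, assuming the approach meant here is Arc::try_unwrap followed by Mutex::into_inner (the text above does not name the calls). The count field and the total_after_threads helper are hypothetical, added only so there is something to observe:

```rust
use std::{
    sync::{Arc, Mutex},
    thread,
};

// Hypothetical stand-in for the Stats type; `count` is an assumed field.
#[derive(Debug)]
struct Stats {
    count: u64,
}

impl Stats {
    fn add_stats(&mut self, other: &Stats) {
        self.count += other.count;
    }
}

// Hypothetical helper: runs `threads` workers, then recovers the plain Stats.
fn total_after_threads(threads: u64) -> Stats {
    let shared_stats = Arc::new(Mutex::new(Stats { count: 0 }));

    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let my_stats = Arc::clone(&shared_stats);
            thread::spawn(move || {
                let mut shared = my_stats.lock().unwrap();
                shared.add_stats(&Stats { count: 1 });
            })
        })
        .collect();

    // Joining every thread ensures each clone of the Arc has been dropped.
    for handle in handles {
        handle.join().unwrap();
    }

    // try_unwrap succeeds only when this is the last remaining Arc.
    let mutex = Arc::try_unwrap(shared_stats).expect("other Arc clones still exist");
    // into_inner consumes the Mutex and returns the value it guarded.
    mutex.into_inner().unwrap()
}

fn main() {
    let stats = total_after_threads(5);
    println!("{:?}", stats);
}
```

If another clone of the Arc is still alive, try_unwrap returns the Arc back as an error instead of the inner value, which is why the threads are joined first.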
[A comment in the example] says that there is no multithreading. Why?
Because I got confused! :-)
In the example code, the result of thread::spawn (a JoinHandle) is immediately dropped because it's not stored anywhere. When the handle is dropped, the thread is detached and may or may not ever finish. I was confusing it with JoinGuard, an old, removed API that joined when it was dropped. Sorry for the confusion!
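To make the difference concrete, here is a small sketch that stores each JoinHandle and joins it, so the caller does not return before the workers finish. The run_workers helper and the AtomicUsize counter are assumptions added for illustration, not part of the original example:

```rust
use std::{
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    thread,
};

// Hypothetical helper: spawns `threads` workers and waits for all of them.
fn run_workers(threads: usize) -> usize {
    let finished = Arc::new(AtomicUsize::new(0));

    // Storing the handles keeps the threads joinable instead of detached.
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let finished = Arc::clone(&finished);
            thread::spawn(move || {
                finished.fetch_add(1, Ordering::SeqCst);
            })
        })
        .collect();

    // join blocks until the thread is done and reports a panic as Err.
    for handle in handles {
        handle.join().unwrap();
    }

    finished.load(Ordering::SeqCst)
}

fn main() {
    println!("{} workers finished", run_workers(5));
}
```

Without the join loop, the function could read the counter before some of the detached threads had run at all.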
For a bit of editorial, I suggest avoiding mutability completely:
    use std::{ops::Add, thread};

    #[derive(Debug)]
    struct Stats(u64);

    // Implement addition on our type
    impl Add for Stats {
        type Output = Stats;
        fn add(self, other: Stats) -> Stats {
            Stats(self.0 + other.0)
        }
    }

    fn main() {
        let threads = 5;

        // Start threads to do computation
        let threads: Vec<_> = (0..threads).map(|_| thread::spawn(|| Stats(4))).collect();

        // Join all the threads, fail if any of them failed
        let result: Result<Vec<_>, _> = threads.into_iter().map(|t| t.join()).collect();
        let result = result.unwrap();

        // Add up all the results
        let sum = result.into_iter().fold(Stats(0), |i, sum| sum + i);
        println!("{:?}", sum);
    }
Here, we keep each JoinHandle and then wait for all the threads to finish. We then collect the results and add them all up. This is the common map-reduce pattern. Note that no thread needs any mutability; it all happens in the master thread.