The solution I will eventually go for is an array of ConcurrentHashMaps instead of a single ConcurrentHashMap. This is ad hoc, but it seems to fit my use case. I don't care about the second step (populating the maps) being slow, as it doesn't affect my code's performance. The solution is:
Object Creation:
- Create an array of size t of ConcurrentHashMaps, where t is the number of threads.
- Create an array of Runnables, also of size t.
Array Population (single threaded, not an issue):
- Create the keys and apply a pre-hash function to each, which returns an int in the range 0 ... t-1. In my case this is simply the key modulo t.
- Put the key in the hashmap by accessing the appropriate entry in the array. E.g. if the pre-hash yields index 4, then go for hashArray[4].put(key, value)
Array Iteration (nicely multithreaded, performance gain):
- Assign every thread from the Runnables array the job of iterating over the hashmap with the corresponding index. This should give an iteration roughly t times shorter than the single-threaded one, provided the keys are spread evenly across the maps.
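The three steps above can be sketched roughly as follows. This is only an illustration, not the proof of concept from my project: the class name ShardedMapSketch, the Node type and the helpers preHash, populate and visitAll are all made up for the example, and a List is used in place of a raw array only to avoid Java's generic-array warnings.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

class ShardedMapSketch {
    // Hypothetical value type standing in for the project's real entries.
    static final class Node {
        final int id;
        int value = 0;
        Node(int id) { this.id = id; }
    }

    // Pre-hash: maps a key to a map index in 0 .. t-1 (plain modulo, as in the text).
    static int preHash(int key, int t) {
        return Math.floorMod(key, t);
    }

    // Object creation and population, both single-threaded.
    static List<ConcurrentHashMap<Integer, Node>> populate(int t, int keyCount) {
        List<ConcurrentHashMap<Integer, Node>> shards = new ArrayList<>(t);
        for (int i = 0; i < t; i++) {
            shards.add(new ConcurrentHashMap<>());
        }
        for (int key = 0; key < keyCount; key++) {
            shards.get(preHash(key, t)).put(key, new Node(key));
        }
        return shards;
    }

    // Iteration: one Runnable per map, each run on its own thread.
    static void visitAll(List<ConcurrentHashMap<Integer, Node>> shards) {
        int t = shards.size();
        Runnable[] jobs = new Runnable[t];
        for (int i = 0; i < t; i++) {
            final ConcurrentHashMap<Integer, Node> shard = shards.get(i);
            // Each node lives in exactly one map, so only one thread touches it.
            jobs[i] = () -> {
                for (Node n : shard.values()) {
                    n.value += 1;
                }
            };
        }
        Thread[] threads = new Thread[t];
        for (int i = 0; i < t; i++) {
            threads[i] = new Thread(jobs[i]);
            threads[i].start();
        }
        try {
            // join() makes the workers' writes visible to the calling thread.
            for (Thread th : threads) {
                th.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for workers", e);
        }
    }

    public static void main(String[] args) {
        List<ConcurrentHashMap<Integer, Node>> shards = populate(4, 1000);
        visitAll(shards);
        int visited = 0;
        for (ConcurrentHashMap<Integer, Node> shard : shards) {
            for (Node n : shard.values()) {
                visited += n.value;
            }
        }
        System.out.println("visited " + visited + " nodes across " + shards.size() + " maps");
    }
}
```

Because every key lands in exactly one map, the worker threads never contend for the same entry. The speed-up still depends on the pre-hash spreading the keys evenly.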
The proof-of-concept code has dependencies on the rest of my project, so I can't post it here; to see it, head over to my project on GitHub.
EDIT
Actually, implementing the above proof of concept for my system has proven to be time-consuming, bug-prone and grossly disappointing. Additionally, I've discovered I would have missed many features of the standard library's ConcurrentHashMap. The solution I have been exploring recently, which looks much less ad hoc and much more promising, is to use Scala, which produces bytecode that is fully interoperable with Java. The proof of concept relies on the stunning library described in this paper, and AFAIK it is currently IMPOSSIBLE to achieve a corresponding solution in vanilla Java without writing thousands of lines of code, given the current state of the standard library and the corresponding third-party libraries.
    import scala.collection.parallel.mutable.ParHashMap

    // Mutable node; v is the value updated during the parallel traversal.
    class Node(value: Int, id: Int) {
      var v = value
      var i = id
      override def toString(): String = v.toString
    }

    object testParHashMap {
      // Applied to every entry; foreach on a ParHashMap may call this from several threads.
      def visit(entry: Tuple2[Int, Node]): Unit = {
        entry._2.v += 1
      }

      def main(args: Array[String]): Unit = {
        val hm = new ParHashMap[Int, Node]()
        for (i <- 1 to 10) {
          val node = new Node(0, i)
          hm.put(node.i, node)
        }
        println("========== BEFORE ==========")
        hm.foreach{println}
        hm.foreach{visit}
        println("========== AFTER ==========")
        hm.foreach{println}
      }
    }