Merge lists that share common elements
You can view your list as notation for a graph: for example, ['a','b','c'] is a graph with 3 nodes that are all connected to each other. The problem you are trying to solve is finding the connected components of this graph.
You can use NetworkX for this, which has the advantage that it's pretty much guaranteed to be correct:
import networkx
from networkx.algorithms.components.connected import connected_components

l = [['a', 'b', 'c'], ['b', 'd', 'e'], ['k'], ['o', 'p'], ['e', 'f'], ['p', 'a'], ['d', 'g']]


def to_graph(l):
    """Build an undirected graph whose connected components are the merged groups.

    Every element of every sublist becomes a node, and consecutive elements
    of a sublist are joined by an edge. Chaining them this way is enough to
    connect the whole sublist.
    """
    G = networkx.Graph()
    for part in l:
        # each sublist is a bunch of nodes (added explicitly so that
        # single-element sublists such as ['k'] still show up)
        G.add_nodes_from(part)
        # it also implies a number of edges:
        G.add_edges_from(to_edges(part))
    return G


def to_edges(l):
    """Treat `l` as a path and yield its edges.

    to_edges(['a', 'b', 'c', 'd']) -> ('a', 'b'), ('b', 'c'), ('c', 'd')
    """
    it = iter(l)
    try:
        last = next(it)
    except StopIteration:
        # Empty sublist: no edges. A bare next() here would escape the
        # generator as RuntimeError (PEP 479, Python 3.7+).
        return
    for current in it:
        yield last, current
        last = current


G = to_graph(l)
# connected_components() returns a lazy iterable of node sets, so
# materialize it before printing.
print(list(connected_components(G)))
# prints e.g. [{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'o', 'p'}, {'k'}]
# (element order inside each set may vary)
To solve this efficiently yourself, you have to convert the list into something graph-like anyway, so you might as well use NetworkX from the start.
Algorithm:
1. Take the first set A from the list.
2. For each other set B in the list: if B has common element(s) with A, join B into A and remove B from the list.
3. Repeat step 2 until there is no more overlap with A.
4. Put A into the output.
5. Repeat step 1 with the rest of the list.
So you might want to use sets instead of lists. The following program should do it.
l = [['a', 'b', 'c'], ['b', 'd', 'e'], ['k'], ['o', 'p'], ['e', 'f'], ['p', 'a'], ['d', 'g']]


def merge_overlapping(lists):
    """Merge sublists that share elements and return a list of sets.

    Take the first remaining sublist as a seed set and sweep the rest,
    absorbing every sublist that overlaps it. Repeat the sweep until a full
    pass adds nothing, then emit the seed. Start over with whatever was not
    absorbed. The input list is not modified.
    """
    out = []
    # list() so any iterable works and `while remaining` tests emptiness.
    remaining = list(lists)
    while remaining:
        first, *rest = remaining
        group = set(first)
        prev_size = -1
        # Absorbing one sublist can make an earlier-skipped sublist overlap,
        # so keep sweeping until the group stops growing.
        while len(group) > prev_size:
            prev_size = len(group)
            untouched = []
            for r in rest:
                # isdisjoint short-circuits and builds no temporary sets.
                if group.isdisjoint(r):
                    untouched.append(r)
                else:
                    group.update(r)
            rest = untouched
        out.append(group)
        remaining = rest
    return out


out = merge_overlapping(l)
print(out)
I needed to perform the clustering technique described by the OP millions of times for rather large lists, and therefore wanted to determine which of the methods suggested above is both most accurate and most performant.
I ran 10 trials for input lists sized from 2^0 through 2^9 (1 to 512 sublists) for each method above, using the same input list for each method, and measured the average runtime for each algorithm proposed above in milliseconds. Here are the results:
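These results helped me see that, of the methods that consistently return correct results, @jochen's is the fastest. Among the methods that don't consistently return correct results, mak's solution often does not include all of the input elements (i.e. members of the input sublists are missing from the output), and the solutions of braaksma, cmangla, and asterisk are not guaranteed to be maximally merged. (Note: mak's missing items appear to be an artifact of the Python 3 port in the test code below. There, `l=map(set,l)` creates a one-shot iterator that the nested `enumerate(l)` calls consume, so elements are skipped; `l=list(map(set,l))` avoids this.)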
Interestingly, the two fastest correct algorithms also have the two highest vote counts to date, in the same order.
Here's the code used to run the tests:
# Benchmark harness for the "merge lists that share common elements" answers.
# Each candidate takes a list of lists and should return the connected
# components of the graph in which each sublist's elements are linked.
# The candidates' algorithms are kept as posted; only Python 3 porting bugs
# and crash-level typos are fixed, so their timings stay representative.
from networkx.algorithms.components.connected import connected_components
from itertools import chain
from random import randint, random
from collections import defaultdict, deque
from copy import deepcopy
from functools import reduce
from multiprocessing import Pool
from time import perf_counter
import networkx
import datetime
import os


# @mimomu
def mimomu(l):
    l = deepcopy(l)
    s = set(chain.from_iterable(l))
    for i in s:
        components = [x for x in l if i in x]
        for j in components:
            l.remove(j)
        l += [list(set(chain.from_iterable(components)))]
    return l


# @Howard
def howard(l):
    out = []
    while len(l) > 0:
        first, *rest = l
        first = set(first)
        lf = -1
        while len(first) > lf:
            lf = len(first)
            rest2 = []
            for r in rest:
                if len(first.intersection(set(r))) > 0:
                    first |= set(r)
                else:
                    rest2.append(r)
            rest = rest2
        out.append(first)
        l = rest
    return out


# Nx @Jochen Ritzel
def jochen(l):
    l = deepcopy(l)

    def to_graph(l):
        G = networkx.Graph()
        for part in l:
            # each sublist is a bunch of nodes
            G.add_nodes_from(part)
            # it also implies a number of edges:
            G.add_edges_from(to_edges(part))
        return G

    def to_edges(l):
        """Treat `l` as a path and yield its edges.

        to_edges(['a', 'b', 'c', 'd']) -> ('a', 'b'), ('b', 'c'), ('c', 'd')
        """
        it = iter(l)
        try:
            last = next(it)
        except StopIteration:
            # Empty sublist: no edges. A bare next() would escape the
            # generator as RuntimeError (PEP 479).
            return
        for current in it:
            yield last, current
            last = current

    G = to_graph(l)
    return list(connected_components(G))


# Merge all @MAK
def mak(l):
    l = deepcopy(l)
    taken = [False] * len(l)
    # Must be a real list: merge_all() and every dfs() call re-iterate it.
    # A bare map() is a one-shot iterator in Python 3, so the nested
    # enumerate() calls would consume it and silently drop elements.
    l = list(map(set, l))

    def dfs(node, index):
        taken[index] = True
        ret = node
        for i, item in enumerate(l):
            if not taken[i] and not ret.isdisjoint(item):
                ret.update(dfs(item, i))
        return ret

    def merge_all():
        ret = []
        for i, node in enumerate(l):
            if not taken[i]:
                ret.append(list(dfs(node, i)))
        return ret

    result = list(merge_all())
    return result


# @cmangla
def cmangla(l):
    l = deepcopy(l)
    len_l = len(l)
    i = 0
    while i < (len_l - 1):
        for j in range(i + 1, len_l):
            # i,j iterate over all pairs of l's elements including new
            # elements from merged pairs. We use len_l because len(l)
            # may change as we iterate
            i_set = set(l[i])
            j_set = set(l[j])
            if len(i_set.intersection(j_set)) > 0:
                # Remove these two from list
                l.pop(j)
                l.pop(i)
                # Merge them and append to the orig. list
                ij_union = list(i_set.union(j_set))
                l.append(ij_union)
                # len(l) has changed
                len_l -= 1
                # adjust 'i' because elements shifted
                i -= 1
                # abort inner loop, continue with next l[i]
                break
        i += 1
    return l


# @pillmuncher
def pillmuncher(l):
    l = deepcopy(l)

    def connected_components(lists):
        neighbors = defaultdict(set)
        seen = set()
        for each in lists:
            for item in each:
                neighbors[item].update(each)

        def component(node, neighbors=neighbors, seen=seen, see=seen.add):
            nodes = set([node])
            next_node = nodes.pop
            while nodes:
                node = next_node()
                see(node)
                nodes |= neighbors[node] - seen
                yield node

        for node in neighbors:
            if node not in seen:
                yield sorted(component(node))

    return list(connected_components(l))


# @NicholasBraaksma
def braaksma(l):
    l = deepcopy(l)
    # Sort each sublist and the list of sublists (the author's note: needed
    # so that merges are not missed).
    lists = sorted([sorted(x) for x in l])
    # Create the empty result list. (Was misspelled `resultslist`, which made
    # an empty input raise NameError at the return below.)
    resultlist = []
    if len(lists) >= 1:  # an empty input needs no work
        resultlist = [lists[0]]  # seed the results with the first list
        if len(lists) > 1:
            for l in lists[1:]:
                listset = set(l)
                merged = False
                for index in range(len(resultlist)):
                    rset = set(resultlist[index])
                    if len(listset & rset) != 0:
                        # Overlap: replace that result with the union.
                        resultlist[index] = list(listset | rset)
                        merged = True
                        break  # stop at the first matching result
                if not merged:
                    # No match: keep the list so it doesn't get left out.
                    resultlist.append(l)
    return resultlist


# @Rumple Stiltskin
def stiltskin(l):
    l = deepcopy(l)
    hashdict = defaultdict(int)

    def hashit(x, y):
        for i in y:
            x[i] += 1
        return x

    def merge(x, y):
        sums = sum([hashdict[i] for i in y])
        if sums > len(y):
            x[0] = x[0].union(y)
        else:
            x[1] = x[1].union(y)
        return x

    # reduce is no longer a builtin in Python 3; imported from functools above.
    hashdict = reduce(hashit, l, hashdict)
    sets = reduce(merge, l, [set(), set()])
    return list(sets)


# @Asterisk
def asterisk(l):
    l = deepcopy(l)
    results = {}
    for sm in ['min', 'max']:
        sort_method = min if sm == 'min' else max
        l = sorted(l, key=lambda x: sort_method(x))
        queue = deque(l)
        grouped = []
        while len(queue) >= 2:
            l1 = queue.popleft()
            l2 = queue.popleft()
            s1 = set(l1)
            s2 = set(l2)
            if s1 & s2:
                queue.appendleft(s1 | s2)
            else:
                grouped.append(s1)
                queue.appendleft(s2)
        if queue:
            grouped.append(queue.pop())
        results[sm] = grouped
    if len(results['min']) < len(results['max']):
        return results['min']
    return results['max']


# Reference components (for validation)
def _component_roots(L):
    """Map every item of L to a representative of its true component.

    Union-find over the items. Consecutive members of each sublist are
    linked, which is enough to connect the whole sublist.
    """
    parent = {}

    def find(x):
        root = x
        while parent[root] != root:
            root = parent[root]
        # Path compression keeps later lookups short.
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    for sub in L:
        for item in sub:
            parent.setdefault(item, item)
        for a, b in zip(sub, sub[1:]):
            ra, rb = find(a), find(b)
            if ra != rb:
                parent[ra] = rb
    return {item: find(item) for item in parent}


# Validate no more clusters can be merged
def validate(output, L):
    """Return 'true' if output is exactly the merged clustering of L.

    Otherwise return the first problem found:
    'not maximally merged' - an item sits in more than one cluster, or one
                             true component is split across clusters;
    'missing items'        - the clusters do not cover exactly L's items;
    'over merged'          - a cluster joins items that are not connected.
    """
    # Each item may belong to only one cluster. Collect *distinct* cluster
    # indices per item, so that an item repeated inside a single cluster is
    # not mistaken for a split.
    clusters_of = defaultdict(set)
    for idx, cluster in enumerate(output):
        for item in cluster:
            clusters_of[item].add(idx)
    if any(len(c) > 1 for c in clusters_of.values()):
        return 'not maximally merged'
    # validate all items in L are accounted for
    all_items = set(chain.from_iterable(L))
    accounted_items = set(chain.from_iterable(output))
    if all_items != accounted_items:
        return 'missing items'
    # Compare against the reference components: a cluster must not span two
    # components, and a component must not be spread over two clusters.
    roots = _component_roots(L)
    cluster_of_root = {}
    for idx, cluster in enumerate(output):
        cluster_roots = {roots[item] for item in cluster}
        if len(cluster_roots) > 1:
            return 'over merged'
        for root in cluster_roots:
            if cluster_of_root.setdefault(root, idx) != idx:
                return 'not maximally merged'
    # validate results are good
    return 'true'


# Timers
def time(func, L):
    """Run func(L) and return (result, elapsed), with elapsed as a timedelta.

    perf_counter is monotonic and meant for measuring short durations; the
    previous datetime.now() is wall-clock time and can be adjusted mid-run.
    """
    start = perf_counter()
    result = func(L)
    delta = datetime.timedelta(seconds=perf_counter() - start)
    return result, delta


# Function runner
def run_func(args):
    func, L, input_size = args
    results, elapsed = time(func, L)
    validation_result = validate(results, L)
    return func.__name__, input_size, elapsed, validation_result


# Main
def main():
    all_results = defaultdict(lambda: defaultdict(list))
    funcs = [mimomu, howard, jochen, mak, cmangla, braaksma, asterisk]
    args = []
    for trial in range(10):
        for s in range(10):
            input_size = 2**s
            # get some random inputs to use for all methods at this size
            L = [[randint(0, 2**24) for _ in range(randint(5, 10))]
                 for _ in range(input_size)]
            for func in funcs:
                args.append([func, L, input_size])

    with Pool() as pool:
        # imap yields in submission order, so every task is finished once
        # this loop ends (the pool is terminated on leaving the `with`).
        for result in pool.imap(run_func, args):
            func_name, input_size, elapsed, validation_result = result
            all_results[func_name][input_size].append({
                'time': elapsed,
                'validation': validation_result,
            })
            # show the running time for the function at this input size
            print(input_size, func_name, elapsed, validation_result)

    # write the average of time trials at each size for each function
    with open('times.tsv', 'w') as out:
        for func, by_size in all_results.items():
            validations = [r['validation'] for runs in by_size.values() for r in runs]
            if any(v != 'true' for v in validations):
                linetype = 'incorrect results'
            else:
                linetype = 'correct results'
            for input_size, runs in by_size.items():
                # total_seconds() covers the whole duration; the old
                # `.microseconds` was only the sub-second component and was
                # not milliseconds, despite the plot's axis label.
                all_times = [r['time'].total_seconds() * 1000 for r in runs]
                avg_time = sum(all_times) / len(all_times)
                out.write('\t'.join([func, str(input_size), str(avg_time), linetype]) + '\n')


# The guard is required by multiprocessing's "spawn" start method (Windows,
# macOS), whose workers re-import this module.
if __name__ == '__main__':
    main()
And for plotting:
library(ggplot2)

# times.tsv has no header row; its columns are: method name, input size,
# average runtime, and the correctness label used for faceting.
timings <- read.table('times.tsv', sep = '\t',
                      col.names = c('method', 'input_size', 'runtime', 'validity'))

runtime_plot <- ggplot(timings,
                       aes(x = input_size, y = runtime, color = as.factor(method))) +
  geom_line() +
  xlab('number of input lists') +
  ylab('runtime (ms)') +
  labs(color = '') +
  scale_x_continuous(trans = 'log10') +
  facet_wrap(~validity, ncol = 1)

# Save the plot object explicitly instead of relying on last_plot().
ggsave('runtimes.png', plot = runtime_plot)