Merge lists that share common elements (python)

Merge lists that share common elements


You can see your list as a notation for a Graph, ie ['a','b','c'] is a graph with 3 nodes connected to each other. The problem you are trying to solve is finding connected components in this graph.

You can use NetworkX for this, which has the advantage that it's pretty much guaranteed to be correct:

import networkx
from networkx.algorithms.components.connected import connected_components

l = [['a', 'b', 'c'], ['b', 'd', 'e'], ['k'], ['o', 'p'], ['e', 'f'], ['p', 'a'], ['d', 'g']]


def to_graph(l):
    """Build an undirected Graph with one node per element.

    Each sublist contributes its elements as nodes plus a chain of edges
    linking them, so all members of a sublist end up in one connected
    component.
    """
    G = networkx.Graph()
    for part in l:
        # each sublist is a bunch of nodes
        G.add_nodes_from(part)
        # it also implies a number of edges:
        G.add_edges_from(to_edges(part))
    return G


def to_edges(l):
    """
    treat `l` as a Graph and returns it's edges
    to_edges(['a','b','c','d']) -> [(a,b), (b,c),(c,d)]
    """
    it = iter(l)
    last = next(it)
    for current in it:
        yield last, current
        last = current


G = to_graph(l)
# connected_components() returns a generator of sets in current networkx,
# so materialize it before printing (and use the Python 3 print function).
print(list(connected_components(G)))
# prints [{'a', 'c', 'b', 'e', 'd', 'g', 'f', 'o', 'p'}, {'k'}]

To solve this efficiently yourself you have to convert the list into something graph-ish anyways, so you might as well use networkX from the start.


Algorithm:

  1. take first set A from list
  2. for each other set B in the list do if B has common element(s) with A join B into A; remove B from list
  3. repeat 2. until no more overlap with A
  4. put A into output
  5. repeat 1. with rest of list

So you might want to use sets instead of list. The following program should do it.

# Merge sublists that share elements: grow the first group to a fixed point
# (it keeps absorbing any remaining sublist that overlaps it), emit it, then
# repeat on whatever is left. Result: `out` is a list of disjoint sets.
l = [['a', 'b', 'c'], ['b', 'd', 'e'], ['k'], ['o', 'p'], ['e', 'f'], ['p', 'a'], ['d', 'g']]
out = []
while l:
    first, *rest = l
    first = set(first)
    lf = -1
    # Sweep `rest` until `first` stops growing; a single pass is not enough
    # because a later sublist may connect two earlier, unrelated ones.
    while len(first) > lf:
        lf = len(first)
        rest2 = []
        for r in rest:
            # isdisjoint() short-circuits and avoids building a throwaway
            # intersection set on every comparison.
            if first.isdisjoint(r):
                rest2.append(r)
            else:
                first.update(r)
        rest = rest2
    out.append(first)
    l = rest
print(out)


I needed to perform the clustering technique described by the OP millions of times for rather large lists, and therefore wanted to determine which of the methods suggested above is both most accurate and most performant.

I ran 10 trials for input lists sized from 2^1 through 2^10 for each method above, using the same input list for each method, and measured the average runtime for each algorithm proposed above in milliseconds. Here are the results:

(benchmark results figure: average runtime in milliseconds vs. number of input lists, one line per method, faceted by correctness)

These results helped me see that of the methods that consistently return correct results, @jochen's is the fastest. Among those methods that don't consistently return correct results, mak's solution often does not include all of the input elements (i.e. list of list members are missing), and the solutions of braaksma, cmangla, and asterisk are not guaranteed to be maximally merged.

Interestingly, the two fastest correct algorithms also have the two highest upvote counts to date, ranked in the same order.

Here's the code used to run the tests:

from networkx.algorithms.components.connected import connected_components
from itertools import chain
from random import randint, random
from collections import defaultdict, deque
from copy import deepcopy
from functools import reduce  # FIX: stiltskin() uses reduce; it's not a builtin in Python 3
from multiprocessing import Pool
import networkx
import datetime
import os

##
# @mimomu
##


def mimomu(l):
    l = deepcopy(l)
    s = set(chain.from_iterable(l))
    for i in s:
        components = [x for x in l if i in x]
        for j in components:
            l.remove(j)
        l += [list(set(chain.from_iterable(components)))]
    return l


##
# @Howard
##


def howard(l):
    out = []
    while len(l) > 0:
        first, *rest = l
        first = set(first)
        lf = -1
        while len(first) > lf:
            lf = len(first)
            rest2 = []
            for r in rest:
                if len(first.intersection(set(r))) > 0:
                    first |= set(r)
                else:
                    rest2.append(r)
            rest = rest2
        out.append(first)
        l = rest
    return out


##
# Nx @Jochen Ritzel
##


def jochen(l):
    l = deepcopy(l)

    def to_graph(l):
        G = networkx.Graph()
        for part in l:
            # each sublist is a bunch of nodes
            G.add_nodes_from(part)
            # it also implies a number of edges:
            G.add_edges_from(to_edges(part))
        return G

    def to_edges(l):
        """
        treat `l` as a Graph and returns it's edges
        to_edges(['a','b','c','d']) -> [(a,b), (b,c),(c,d)]
        """
        it = iter(l)
        last = next(it)
        for current in it:
            yield last, current
            last = current

    G = to_graph(l)
    return list(connected_components(G))


##
# Merge all @MAK
##


def mak(l):
    l = deepcopy(l)
    taken = [False] * len(l)
    # FIX: in Python 3 map() returns a one-shot iterator; both dfs() and
    # merge_all() enumerate `l`, so the second enumeration would see an
    # exhausted iterator. Materialize it to keep the original Py2 semantics.
    l = list(map(set, l))

    def dfs(node, index):
        taken[index] = True
        ret = node
        for i, item in enumerate(l):
            if not taken[i] and not ret.isdisjoint(item):
                ret.update(dfs(item, i))
        return ret

    def merge_all():
        ret = []
        for i, node in enumerate(l):
            if not taken[i]:
                ret.append(list(dfs(node, i)))
        return ret

    result = list(merge_all())
    return result


##
# @cmangla
##


def cmangla(l):
    l = deepcopy(l)
    len_l = len(l)
    i = 0
    while i < (len_l - 1):
        for j in range(i + 1, len_l):
            # i,j iterate over all pairs of l's elements including new
            # elements from merged pairs. We use len_l because len(l)
            # may change as we iterate
            i_set = set(l[i])
            j_set = set(l[j])
            if len(i_set.intersection(j_set)) > 0:
                # Remove these two from list
                l.pop(j)
                l.pop(i)
                # Merge them and append to the orig. list
                ij_union = list(i_set.union(j_set))
                l.append(ij_union)
                # len(l) has changed
                len_l -= 1
                # adjust 'i' because elements shifted
                i -= 1
                # abort inner loop, continue with next l[i]
                break
        i += 1
    return l


##
# @pillmuncher
##


def pillmuncher(l):
    l = deepcopy(l)

    def connected_components(lists):
        neighbors = defaultdict(set)
        seen = set()
        for each in lists:
            for item in each:
                neighbors[item].update(each)

        def component(node, neighbors=neighbors, seen=seen, see=seen.add):
            nodes = set([node])
            next_node = nodes.pop
            while nodes:
                node = next_node()
                see(node)
                nodes |= neighbors[node] - seen
                yield node

        for node in neighbors:
            if node not in seen:
                yield sorted(component(node))

    return list(connected_components(l))


##
# @NicholasBraaksma
##


def braaksma(l):
    l = deepcopy(l)
    # Sorts lists in place so you dont miss things. Trust me, needs to be done.
    lists = sorted([sorted(x) for x in l])
    resultlist = []  # Create the empty result list.
    if len(lists) >= 1:  # If your list is empty then you dont need to do anything.
        resultlist = [lists[0]]  # Add the first item to your resultset
        if len(lists) > 1:  # If there is only one list in your list then you dont need to do anything.
            for l in lists[1:]:  # Loop through lists starting at list 1
                listset = set(l)  # Turn you list into a set
                merged = False  # Trigger
                for index in range(len(resultlist)):  # Use indexes of the list for speed.
                    rset = set(resultlist[index])  # Get list from you resultset as a set
                    if len(listset & rset) != 0:  # If listset and rset have a common value then the len will be greater than 1
                        resultlist[index] = list(listset | rset)  # Update the resultlist with the updated union of listset and rset
                        merged = True  # Turn trigger to True
                        break  # Because you found a match there is no need to continue the for loop.
                if not merged:  # If there was no match then add the list to the resultset, so it doesnt get left out.
                    resultlist.append(l)
    return resultlist


##
# @Rumple Stiltskin
##


def stiltskin(l):
    l = deepcopy(l)
    hashdict = defaultdict(int)

    def hashit(x, y):
        for i in y:
            x[i] += 1
        return x

    def merge(x, y):
        sums = sum([hashdict[i] for i in y])
        if sums > len(y):
            x[0] = x[0].union(y)
        else:
            x[1] = x[1].union(y)
        return x

    hashdict = reduce(hashit, l, hashdict)
    sets = reduce(merge, l, [set(), set()])
    return list(sets)


##
# @Asterisk
##


def asterisk(l):
    l = deepcopy(l)
    results = {}
    for sm in ['min', 'max']:
        sort_method = min if sm == 'min' else max
        l = sorted(l, key=lambda x: sort_method(x))
        queue = deque(l)
        grouped = []
        while len(queue) >= 2:
            l1 = queue.popleft()
            l2 = queue.popleft()
            s1 = set(l1)
            s2 = set(l2)
            if s1 & s2:
                queue.appendleft(s1 | s2)
            else:
                grouped.append(s1)
                queue.appendleft(s2)
        if queue:
            grouped.append(queue.pop())
        results[sm] = grouped
    if len(results['min']) < len(results['max']):
        return results['min']
    return results['max']


##
# Validate no more clusters can be merged
##


def validate(output, L):
    """Check a clustering result: returns 'true' when every element of L is
    present in exactly one output group, otherwise a short failure reason."""
    # validate all sublists are maximally merged: if any element appears in
    # more than one group, those groups could still be merged
    d = defaultdict(list)
    for idx, i in enumerate(output):
        for j in i:
            d[j].append(i)
    if any([len(i) > 1 for i in d.values()]):
        return 'not maximally merged'
    # validate all items in L are accounted for
    all_items = set(chain.from_iterable(L))
    accounted_items = set(chain.from_iterable(output))
    if all_items != accounted_items:
        return 'missing items'
    # validate results are good
    return 'true'


##
# Timers
##


def time(func, L):
    """Run func(L), returning its result and the elapsed wall time as a timedelta."""
    start = datetime.datetime.now()
    result = func(L)
    delta = datetime.datetime.now() - start
    return result, delta


##
# Function runner
##


def run_func(args):
    """Worker for Pool.imap: time one function on one input and validate it."""
    func, L, input_size = args
    results, elapsed = time(func, L)
    validation_result = validate(results, L)
    return func.__name__, input_size, elapsed, validation_result


##
# Main
##

all_results = defaultdict(lambda: defaultdict(list))
funcs = [mimomu, howard, jochen, mak, cmangla, braaksma, asterisk]

args = []
for trial in range(10):
    for s in range(10):
        input_size = 2 ** s
        # get some random inputs to use for all trials at this size
        L = []
        for i in range(input_size):
            sublist = []
            for j in range(randint(5, 10)):
                sublist.append(randint(0, 2 ** 24))
            L.append(sublist)
        for i in funcs:
            args.append([i, L, input_size])

pool = Pool()
for result in pool.imap(run_func, args):
    func_name, input_size, elapsed, validation_result = result
    all_results[func_name][input_size].append({
        'time': elapsed,
        'validation': validation_result,
    })
    # show the running time for the function at this input size
    print(input_size, func_name, elapsed, validation_result)
pool.close()
pool.join()

# write the average of time trials at each size for each function
with open('times.tsv', 'w') as out:
    for func in all_results:
        validations = [i['validation'] for j in all_results[func] for i in all_results[func][j]]
        linetype = 'incorrect results' if any([i != 'true' for i in validations]) else 'correct results'
        for input_size in all_results[func]:
            # FIX: timedelta.microseconds is only the sub-second component
            # (it wraps at 1s) and is in microseconds, while the plot labels
            # the axis in milliseconds. total_seconds() * 1000 gives true ms.
            all_times = [i['time'].total_seconds() * 1000 for i in all_results[func][input_size]]
            avg_time = sum(all_times) / len(all_times)
            out.write(func + '\t' + str(input_size) + '\t' +
                      str(avg_time) + '\t' + linetype + '\n')

And for plotting:

# Plot the benchmark runtimes written by the Python harness (times.tsv).
library(ggplot2)

# Columns: V1 = function name, V2 = input-list count, V3 = avg runtime,
# V4 = correctness label ('correct results' / 'incorrect results').
df <- read.table('times.tsv', sep='\t')

# One line per method, log-scaled x axis, faceted by correctness.
p <- ggplot(df, aes(x=V2, y=V3, color=as.factor(V1))) +
  geom_line() +
  xlab('number of input lists') +
  ylab('runtime (ms)') +
  labs(color='') +
  scale_x_continuous(trans='log10') +
  facet_wrap(~V4, ncol=1)

ggsave('runtimes.png')