Source code for gempipe.interface.clusters

import pandas as pnd
import seaborn as sb
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, cut_tree, dendrogram, leaves_list
from sklearn.metrics import silhouette_score, silhouette_samples
import numpy as np    


from gempipe.interface.clusters_utils import merge_tables, sort_by_leaves, make_dendrogram, make_colorbar_clusters, make_colorbar_metadata, make_legends, subset_k_best

    
    
def silhouette_analysis(
        tables, figsize=(10, 5), drop_const=True, ctotest=None, forcen=None, derive_report=None,
        report_key='species', excludekeys=None, legend_ratio=0.7, outfile=None, verbose=False,
        anchor=None, key_to_color=None):
    """Perform a silhouette analysis to detect the optimal number of clusters.

    Args:
        tables (pnd.DataFrame): feature tables with genome accessions are in columns and features are in rows.
            Can also be a dictionary of feature tables (example: ``{'auxotrophies': aux_df, 'substrates': sub_df})``.
            In this case, any number of tables (pandas.DataFrame) can be used.
            For each table, genome accessions are in columns, features are in rows.
            Directly compatible tables are: `rpam.csv`, `cnps.csv`, and `aux.csv` (all produced by `gempipe derive`).
        figsize (int, int): width and height of the figure.
        drop_const (bool): if `True`, remove constant features.
        ctotest (list): number of clusters to test (example: ``[5,7,10]`` to test five, seven and ten clusters).
            If `None`, the values produced by ``np.arange(2, n_accessions - 1)`` will be used.
        forcen (int): force the number of clusters, otherwise the optimal number will be picked up
            according to the average silhouette value.
        derive_report (pandas.DataFrame): report table for the generation of strain-specific GSMMs,
            made by `gempipe derive` in the output directory (`derive_strains.csv`).
        report_key (str): name of the attribute (column) appearing in `derive_report`, to be compared to the
            metabolic clusters. Usually it is 'species' or 'niche'.
        excludekeys (list): keys (niches/species) not to show in legend. ``None`` (default) means no key is excluded.
            Known limitation: no more than 1 key is allowed.
        legend_ratio (float): space reserved for the legend.
        outfile (str): filepath to be used to save the image. If `None` it will not be saved.
        verbose (bool): if `True`, print more log messages.
        anchor (list): list of tuples (X,Y) for customizing the position of legends.
            A ``None`` entry will leave default positioning; passing ``None`` (default) is the same
            as passing ``[None, None, None]``.
        key_to_color (dict): dict mapping each category in `report_key` to a color in the format ([0:1],[0:1],[0:1]).
            ``None`` will leave default color and order in the legend.

    Returns:
        tuple: A tuple containing:
            - matplotlib.figure.Figure: figure representing the silhouette analysis.
            - dict: genome-to-cluster associations (clusters are numbered starting from 1).
            - dict: an RGB color for each cluster (clusters are numbered starting from 1).

    Raises:
        ValueError: if there is no number of clusters to test.
    """
    # Build fresh containers on every call instead of sharing mutable defaults.
    if excludekeys is None:
        excludekeys = []
    if anchor is None:
        anchor = [None, None, None]

    def create_silhuette_frame(figsize):
        # create the subplots (the narrow 0.02 columns are only spacers):
        fig, axs = plt.subplots(
            nrows=1, ncols=10,
            figsize=figsize,  # global dimensions.
            gridspec_kw={'width_ratios': [0.46, 0.02, 0.46, 0.02, 0.3, 0.04, 0.02, 0.04, 0.02, legend_ratio]})
        # adjust the space between subplots:
        plt.subplots_adjust(wspace=0, hspace=0)
        for spacer in (1, 3, 6, 8):
            axs[spacer].axis('off')  # remove frame and axis
        return fig, axs

    def make_plot_ncluster_comparison(ax, num_clusters_vector, silhouette_avg_scores, opt_n_clusters, forcen, verbose):
        # Plot the average silhouette score against the number of clusters.
        ax.plot(num_clusters_vector, silhouette_avg_scores, marker='o')
        ax.set_xlabel('N clusters')
        ax.set_ylabel('Average Silhouette Score')
        ax.grid(True)
        if verbose:
            print(f"Optimal number of clusters: {opt_n_clusters}")
        # the red line marks the number of clusters actually used downstream:
        ax.axvline(x=opt_n_clusters if forcen is None else forcen, color='red', linestyle='--')

    def make_plot_silhouette_coeff(ax, opt_n_clusters, silhouette_scores, clusters):
        # Given a fixed number of clusters, extract the datapoints belonging to each
        # cluster and show their silhouette scores as one filled polygon per cluster.
        y_lower = 0
        cluster_to_color = {}
        for i in range(opt_n_clusters):
            # scores of the datapoints inside the cluster i (boolean indexing returns a copy,
            # so the in-place sort does not touch 'silhouette_scores'):
            cluster_i_scores = silhouette_scores[clusters == i]
            cluster_i_scores.sort()  # sort from smallest to biggest
            size_cluster_i = len(cluster_i_scores)

            # get the limits for this polygon:
            y_upper = y_lower + size_cluster_i

            # get the color for this cluster/polygon (clusters are exposed 1-based):
            color = plt.cm.Spectral(float(i) / opt_n_clusters)
            cluster_to_color[i + 1] = color

            ax.fill_betweenx(
                np.arange(y_lower, y_upper), 0, cluster_i_scores,
                facecolor=color, edgecolor=color, alpha=1.0)
            ax.text(0, (y_lower + y_upper - 1) / 2, f'Cluster_{i+1}', va='center', ha='left')
            y_lower = y_upper - 1  # no space between clusters

        ax.set_xlabel('Silhouette Coefficient')
        ax.set_ylabel('')
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.set_facecolor('#f8f8f8')  # Light gray background
        return cluster_to_color

    # START:

    # format input tables:
    data, _ = merge_tables(tables)
    data_bool = data.astype(bool)  # convert multi-layer (0, 1, 2, 3, ...) into binary

    # the user may want to drop constant columns:
    if drop_const:
        constant_columns = [col for col in data.columns if data[col].nunique() == 1]
        if verbose:
            print(f"WARNING: removing {len(constant_columns)} constant features.")
        data_bool = data_bool.drop(columns=constant_columns)

    # pdist() / linkage() will lose the accession information. So here we save a dict:
    index_to_acc = dict(enumerate(data_bool.index))

    # Calculate the linkage matrix using Ward clustering and Jaccard dissimilarity
    distances = pdist(data_bool, 'jaccard')
    linkage_matrix = linkage(distances, method='ward')

    # get the vector of number of clusters to test:
    if ctotest is not None:
        num_clusters_vector = list(ctotest)
    else:
        num_clusters_vector = np.arange(2, len(data_bool) - 1, 1)
    if len(num_clusters_vector) == 0:
        raise ValueError(
            "No number of clusters to test: provide a non-empty 'ctotest' or more accessions.")

    # creates empty plots:
    fig, axs = create_silhuette_frame(figsize)

    # Average silhouette score for every tested number of clusters.
    # NOTE(review): silhouette_score() is called with its default metric while the tree is
    # built on Jaccard distances -- confirm this is intended.
    silhouette_avg_scores = [
        silhouette_score(data_bool, cut_tree(linkage_matrix, n_clusters=num_clusters).flatten())
        for num_clusters in num_clusters_vector]

    # get the max average silhouette (optimal value). The optimal number of clusters is read
    # back from the tested vector, so that it is correct also when 'ctotest' is provided.
    max_index = int(np.argmax(silhouette_avg_scores))
    opt_n_clusters = int(num_clusters_vector[max_index])

    # Plot the average silhouette (average on each datapoint).
    make_plot_ncluster_comparison(axs[0], num_clusters_vector, silhouette_avg_scores, opt_n_clusters, forcen, verbose)

    # Given the optimal (or forced) number of clusters, visualize the silhouette score for each data point.
    if forcen is not None:
        opt_n_clusters = forcen
    clusters = cut_tree(linkage_matrix, n_clusters=opt_n_clusters).flatten()
    # clusters are 0-based here, but exposed 1-based to the user:
    acc_to_cluster = {accession: clusters[index] + 1 for index, accession in index_to_acc.items()}
    silhouette_avg = silhouette_score(data_bool, clusters)
    if verbose:
        print(f'Avg silhouette score when {opt_n_clusters} clusters:', silhouette_avg)

    # 'silhouette_scores' is positionally aligned with 'clusters', thus the scores for
    # a specific cluster are later obtained with 'silhouette_scores[clusters == i]'.
    silhouette_scores = silhouette_samples(data_bool, clusters)

    # Show silhouette scores for each datapoint (given the optimal number of clusters)
    cluster_to_color = make_plot_silhouette_coeff(axs[2], opt_n_clusters, silhouette_scores, clusters)

    # Plot the dendrogram
    make_dendrogram(axs[4], linkage_matrix)

    # order the dataframe following the leaves of the tree:
    ord_data_bool = sort_by_leaves(data_bool, linkage_matrix, index_to_acc)

    # add colorbar for the dendrogram
    make_colorbar_clusters(axs[5], ord_data_bool, acc_to_cluster, cluster_to_color)

    # add colorbar for the species/niches
    make_colorbar_metadata(axs[7], ord_data_bool, derive_report, report_key, excludekeys, key_to_color)

    # make legends
    make_legends(axs[9], derive_report, report_key, excludekeys, cluster_to_color, None, anchor, key_to_color)

    # save to disk; bbox_inches='tight' removes white spaces around the figure.
    if outfile is not None:
        fig.savefig(outfile, dpi=300, bbox_inches='tight')

    fig.set_dpi(300)
    fig.tight_layout()
    return (fig, acc_to_cluster, cluster_to_color)
def heatmap_multilayer(
        tables, figsize=(10, 5), drop_const=True, derive_report=None, report_key='species', excludekeys=None,
        acc_to_cluster=None, cluster_to_color=None, legend_ratio=0.7, label_ratio=0.02, outfile=None,
        verbose=False, anchor=None, key_to_color=None, xlabels=False):
    """Create a phylo-metabolic dendrogram.

    Args:
        tables (pnd.DataFrame): feature tables with genome accessions are in columns and features are in rows.
            Can also be a dictionary of feature tables (example: ``{'auxotrophies': aux_df, 'substrates': sub_df})``.
            In this case, any number of tables (pandas.DataFrame) can be used.
            For each table, genome accessions are in columns, features are in rows.
            Directly compatible tables are: `rpam.csv`, `cnps.csv`, and `aux.csv` (all produced by `gempipe derive`).
        figsize (int, int): width and height of the figure.
        drop_const (bool): if `True`, remove constant features.
        derive_report (pandas.DataFrame): report table for the generation of strain-specific GSMMs,
            made by `gempipe derive` in the output directory (`derive_strains.csv`).
        report_key (str): name of the attribute (column) appearing in `derive_report`, to be compared to the
            metabolic clusters. Usually it is 'species' or 'niche'.
        excludekeys (list): keys (niches/species) not to show in legend. ``None`` (default) means no key is excluded.
            Known limitation: no more than 1 key is allowed.
        acc_to_cluster (dict): genome-to-cluster associations produced by `silhouette_analysis()`.
        cluster_to_color (dict): cluster-to-RGB color associations produced by `silhouette_analysis()`.
        legend_ratio (float): space reserved for the legend.
        label_ratio (float): space reserved for the Y-axis labels.
        outfile (str): filepath to be used to save the image. If `None` it will not be saved.
        verbose (bool): if `True`, print more log messages.
        anchor (list): list of tuples (X,Y) for customizing the position of legends.
            A ``None`` entry will leave default positioning; passing ``None`` (default) is the same
            as passing ``[None, None, None]``.
        key_to_color (dict): dict mapping each category in `report_key` to a color in the format ([0:1],[0:1],[0:1]).
            ``None`` will leave default color and order in the legend.
        xlabels (bool): if `True`, show x-axis labels (feature IDs).

    Returns:
        tuple: A tuple containing:
            - matplotlib.figure.Figure: figure representing the phylometabolic tree and associated heatmap.
            - pnd.DataFrame: table representing the binary features contained in the heatmap.
    """
    # Build fresh containers on every call instead of sharing mutable defaults.
    if excludekeys is None:
        excludekeys = []
    if anchor is None:
        anchor = [None, None, None]

    def create_heatmap_frame(figsize):
        # create the subplots (the narrow 0.02 columns and the label column are only spacers):
        fig, axs = plt.subplots(
            nrows=1, ncols=8,
            figsize=figsize,  # global dimensions.
            gridspec_kw={'width_ratios': [0.3, 0.04, 0.02, 0.04, 0.02, 0.94, label_ratio, legend_ratio]})
        # adjust the space between subplots:
        plt.subplots_adjust(wspace=0, hspace=0)
        for spacer in (2, 4, 6):
            axs[spacer].axis('off')  # remove frame and axis
        return fig, axs

    def make_plot_heatmap_multilayer(ax, ord_data, xlabels):
        ax.matshow(
            ord_data,
            cmap='viridis',
            vmin=ord_data.min().min(), vmax=ord_data.max().max(),  # define ranges for the colormap.
            aspect='auto',  # fixed axes and aspect adjusted to fit data.
            interpolation='none')  # no interp. performed on Agg-ps-pdf-svg backends.

        # hide axes and frame by default:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        for side in ('top', 'bottom', 'right', 'left'):
            ax.spines[side].set_visible(False)

        if xlabels:  # show x-axis labels (feature IDs)
            ax.get_xaxis().set_visible(True)
            ax.set_xticks(range(len(ord_data.columns)))
            ax.set_xticklabels(list(ord_data.columns))
            ax.xaxis.set_ticks_position('bottom')

    # START

    # format input tables:
    data, dict_tables = merge_tables(tables)
    data_bool = data.astype(bool)  # convert multi-layer (0, 1, 2, 3, ...) into binary

    # the user may want to drop constant columns:
    if drop_const:
        constant_columns = [col for col in data.columns if data[col].nunique() == 1]
        if verbose:
            print(f"WARNING: removing {len(constant_columns)} constant features.")
        data = data.drop(columns=constant_columns)
        data_bool = data_bool.drop(columns=constant_columns)

    # pdist() / linkage() will lose the accession information. So here we save a dict:
    index_to_acc = dict(enumerate(data.index))

    # create a dendrogram based on the jaccard distances (dissimilarities):
    distances = pdist(data_bool, 'jaccard')
    linkage_matrix = linkage(distances, method='ward')

    # create the empty figure frame:
    fig, axs = create_heatmap_frame(figsize)

    # plot the dendrogram
    make_dendrogram(axs[0], linkage_matrix)

    # order the dataframes following the leaves of the tree:
    ord_data = sort_by_leaves(data, linkage_matrix, index_to_acc)
    ord_data_bool = sort_by_leaves(data_bool, linkage_matrix, index_to_acc)  # only to return

    # plot the heatmap (multi-layer values, not the binary ones):
    make_plot_heatmap_multilayer(axs[5], ord_data, xlabels)

    # add the cluster information (coming from the silhouette analysis):
    make_colorbar_clusters(axs[1], ord_data, acc_to_cluster, cluster_to_color)

    # colorbar for the species/niche
    make_colorbar_metadata(axs[3], ord_data, derive_report, report_key, excludekeys, key_to_color)

    # make legends
    make_legends(axs[7], derive_report, report_key, excludekeys, cluster_to_color, dict_tables, anchor, key_to_color)

    # save to disk; bbox_inches='tight' removes white spaces around the figure.
    if outfile is not None:
        fig.savefig(outfile, dpi=300, bbox_inches='tight')

    fig.set_dpi(300)
    fig.tight_layout()
    return fig, ord_data_bool
def discriminant_feat(binary_feats, acc_to_cluster, cluster_to_color, threshold=0.90):
    """Extract discriminant features from clusters of strains.

    Args:
        binary_feats (pnd.DataFrame): binary features table such as the one produced by `heatmap_multilayer`
            (genomes in row, binary features in column).
        acc_to_cluster (dict): dictionary such as the one produced by `silhouette_analysis`
            (accessions as keys, cluster assignment as value).
        cluster_to_color (dict): dictionary such as the one produced by `silhouette_analysis`
            (clusters as keys, colors as value). Keys can be either integers or their string representation.
        threshold (float): features are shown if at least one cluster has relative frequency >= `threshold` and,
            at the same time, at least another cluster has relative frequency <= 1-`threshold`.

    Returns:
        tuple: A tuple containing:
            - matplotlib.figure.Figure: figure representing the discriminative binary features.
            - pnd.DataFrame: relative frequency of each discriminative feature (rows) in each cluster (columns).

    Raises:
        KeyError: if an accession of `binary_feats` is missing from `acc_to_cluster`.
        ValueError: if no feature passes the `threshold` filter.
    """

    def get_contingency(binary_feats, feat_id):
        contingency_table = pnd.crosstab(binary_feats[feat_id], binary_feats['y'], margins=False)
        # limit case: '0' or '1' is missing. Reindexing guarantees both rows exist (filled with 0).
        # The resulting pnd.DataFrame will be similar to (e.g. for the binary feat "[aux]his__L"):
        # y            Cluster_1  Cluster_2  Cluster_3  Cluster_4  Cluster_5
        # [aux]his__L
        # 0                   12          3          6         12          0
        # 1                    0          0          0          0          3
        return contingency_table.reindex([0, 1], fill_value=0)

    # START

    # convert to int:
    binary_feats = binary_feats.copy().astype(int)  # .copy() will defragment the dataframe.

    # add the classification column (usually called 'y'); str to avoid ambiguity:
    binary_feats['y'] = ["Cluster_" + str(acc_to_cluster[accession]) for accession in binary_feats.index]

    # get dataframe of relative frequencies (features in rows, clusters in columns):
    feature_ids = list(dict.fromkeys(col for col in binary_feats.columns if col != 'y'))
    df_relfreq = pnd.DataFrame(index=feature_ids, columns=binary_feats['y'].unique(), dtype=float)
    for feat_id in feature_ids:
        cont = get_contingency(binary_feats, feat_id)
        for cluster in df_relfreq.columns:
            n_present, n_absent = cont.loc[1, cluster], cont.loc[0, cluster]
            df_relfreq.loc[feat_id, cluster] = n_present / (n_present + n_absent)
    df_relfreq = df_relfreq.astype(float)

    # filter the dataframe:
    df_relfreq = df_relfreq[(df_relfreq >= threshold).any(axis=1)]
    df_relfreq = df_relfreq[(df_relfreq <= 1 - threshold).any(axis=1)]
    if df_relfreq.empty:
        raise ValueError(f"No discriminant feature found with threshold={threshold}.")

    # invert column order to match that of the heatmap
    df_relfreq = df_relfreq[df_relfreq.columns[::-1]]

    # sort features alphabetically
    df_relfreq = df_relfreq.sort_index(ascending=False)

    # resort index in order to have similar rows (similar features) close together.
    # A hierarchical clustering needs at least two rows, otherwise the order is left untouched.
    if len(df_relfreq) > 1:
        index_to_featid = dict(enumerate(df_relfreq.index))
        distances = pdist(df_relfreq, 'jaccard')
        linkage_matrix = linkage(distances, method='ward')
        df_relfreq = sort_by_leaves(df_relfreq, linkage_matrix, index_to_featid)

    # create the subplots (top: one colored cell per cluster; bottom: the heatmap):
    fig, axs = plt.subplots(
        nrows=2, ncols=1,
        figsize=(0.5 * len(df_relfreq.columns), 0.3 * len(df_relfreq)),  # global dimensions.
        gridspec_kw={'width_ratios': [1], 'height_ratios': [1, 1 * len(df_relfreq)]})
    # adjust the space between subplots:
    plt.subplots_adjust(wspace=0, hspace=0)
    axs[0].set_frame_on(False)  # remove squared border but not ticks
    axs[1].set_frame_on(False)  # remove squared border but not ticks

    # create matshow (clusters). Keys of 'cluster_to_color' may be str or (numpy) int:
    # comparing their string form handles both cases without evaluating any text.
    color_by_cluster_id = {str(key): color for key, color in cluster_to_color.items()}
    cluster_colors = [color_by_cluster_id[cluster.replace('Cluster_', '')] for cluster in df_relfreq.columns]
    cmap_clusters = LinearSegmentedColormap.from_list('', cluster_colors)
    df_clusters = pnd.DataFrame({cluster: [i] for i, cluster in enumerate(df_relfreq.columns)})
    axs[0].matshow(
        df_clusters,
        cmap=cmap_clusters,
        vmin=df_clusters.min().min(), vmax=df_clusters.max().max(),  # define ranges for the colormap.
        aspect='auto',  # fixed axes and aspect adjusted to fit data.
        interpolation='none')  # no interp. performed on Agg-ps-pdf-svg backends.

    # create matshow (heatmap)
    cmap_heatmap = LinearSegmentedColormap.from_list('', ["#DDDDDD", "#8888DD"])
    axs[1].matshow(
        df_relfreq,
        cmap=cmap_heatmap,
        vmin=df_relfreq.min().min(), vmax=df_relfreq.max().max(),  # define ranges for the colormap.
        aspect='auto',  # fixed axes and aspect adjusted to fit data.
        interpolation='none')  # no interp. performed on Agg-ps-pdf-svg backends.

    # add x/y axis labels (clusters):
    axs[0].set_xticks(range(len(df_relfreq.columns)))  # position of x-ticks
    axs[0].set_xticklabels(df_relfreq.columns, ha='left', rotation=30)
    axs[0].set_yticks([])  # remove y ticks

    # add x/y axis labels (heatmap):
    axs[1].set_xticks([])  # remove x ticks
    axs[1].set_yticks(range(len(df_relfreq.index)))  # position of y-ticks
    axs[1].set_yticklabels(df_relfreq.index)

    # add annotations (rel frequencies); white text on the darker cells:
    for i, row_values in enumerate(df_relfreq.to_numpy()):
        for j, freq in enumerate(row_values):
            freq_to_show = round(freq, 2)
            annot_color = 'white' if freq_to_show >= 0.60 else 'black'
            axs[1].text(j, i, f'{freq_to_show}', ha='center', va='center', color=annot_color)

    fig.set_dpi(300)
    fig.tight_layout()
    return fig, df_relfreq