Recreating figures in the manuscript

import os
import pandas as pd
import subprocess
from io import StringIO
import matplotlib.pyplot as plt
import seaborn as sns
import string
import matplotlib.gridspec as gridspec
import geopandas as gpd
import cartopy.crs as ccrs
import numpy as np
import pycoda
from matplotlib.patches import Patch

sns.set_context('paper')
FIG_PATH = '/Users/hanmar/OneDrive - Danmarks Tekniske Universitet/PhD/manuscripts/db_paper/figures'
DATA_PATH = '../data/'
alphabet = list(string.ascii_lowercase)

def query_db(cmd):
    p = subprocess.run(f"mysql -e \"{cmd}\"", shell=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
    return  pd.read_csv(StringIO(p.stdout.decode()), sep='\t')

world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
def figure_saver(fig, prefix, contexts=['paper'], formats=['.png', '.pdf', '.tiff']):
    
    for context in contexts:
        
        for fformat in formats:
            fname = f"{prefix}_{context}{fformat}"
            fig.savefig(fname)
    
    sns.set_context('paper')

def get_cmap(labels, name='viridis'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct 
    RGB color; the keyword argument name must be a standard mpl colormap name.'''
    cmap = plt.get_cmap(name, len(labels))#.colors
    
    return {l: cmap(i) for i, l in enumerate(labels)}
# retrieve metadata
pull = False
metadataFile = os.path.join(DATA_PATH, 'metadata.h5')
if pull:
    metadata = query_db(
        "select run_accession, host, IF( (collection_date <= '2020-01-01') AND (collection_date >= '1800-01-01'), YEAR(collection_date), NULL)  as year , CONCAT(IF(SPLIT_STRING(location, ' ', 2) LIKE 'S', '-', ''), SPLIT_STRING(location, ' ', 1)) as latitude, CONCAT(IF(SPLIT_STRING(location, ' ', 4) LIKE 'W', '-', ''), SPLIT_STRING(location, ' ', 3)) as longitude, raw_reads, raw_bases, trimmed_fragments, instrument_platform, country, continent from Meta_public inner join run_overview using(run_accession, sample_accession, project_accession) where mapped_ResFinder=1 and mapped_Silva=1"
    )
    
    metadata.to_hdf(metadataFile, key='metadata')
else:
    metadata = pd.read_hdf(metadataFile, key='metadata')
continent_order = ['Africa', 'Antarctica', 'Asia', 'Europe', 'North America', 'Oceania', 'South America', 'Arctic Ocean', 'Atlantic Ocean', 'Indian Ocean', 'Norwegian Sea', 'Pacific Ocean', 'Red Sea']

hostCounts = metadata.groupby('host').agg({'run_accession': 'count', 'trimmed_fragments': 'sum'}).sort_values(by='run_accession')
host1k = hostCounts.query("run_accession >= 1000").index
host_index = host1k.str.lower().argsort()
host_order = host1k[host_index]
host_axlabels = [" ".join([h.split(' ')[0].title()] + h.split(' ')[1:]) for h in host1k]

Figure 1

fig = plt.figure(figsize=(14,7))

gs = gridspec.GridSpec(ncols=2, nrows=2, height_ratios=[1, .2], width_ratios=[.2, 1])


ax_map = fig.add_subplot(gs[0, 1])
ax_host = fig.add_subplot(gs[0, 0])
#ax_host_b = fig.add_subplot(gs[0, 1], sharey=ax_host)
ax_time = fig.add_subplot(gs[-1, :])

sns.barplot(data=hostCounts.loc[hostCounts.run_accession >= 1000].reset_index(), x='run_accession', y='host', orient='h', ax=ax_host, color='lightgrey') 
ax_host.set_xscale('log')
ax_host.set_xlabel('Samples (log)')
ax_host.set_ylabel('Sampling host')
ax_host.set_yticklabels(host_axlabels)

#sns.barplot(data=hostCounts.loc[hostCounts.run_accession >= 1000].reset_index(), x='trimmed_fragments', y='host', orient='h', ax=ax_host_b, color='grey') 
#ax_host_b.set_xscale('log')
#ax_host_b.set_xlabel('Fragments (log)')
#

# map
world.plot(ax=ax_map, color="lightgrey")
metadata.plot(x="longitude", y="latitude", ax=ax_map, kind="scatter", color="blue", alpha=.25, s=1)
ax_map.axis('off')
ax_map.set_xlabel('Sampling location')
ax_map.margins(0)
ax_map.apply_aspect()

# year
timeCounts = metadata.groupby('year').agg({'run_accession': 'count', 'trimmed_fragments': 'sum'}).sort_values(by='run_accession').reset_index()
timeCounts['year'] = timeCounts['year'].astype('int')

sns.barplot(
    data=timeCounts.reset_index(),
    x='year',
    y = 'run_accession',
    ax=ax_time,
    color='lightgrey'
)
ax_time.set_yscale('log')
ax_time.set_xlabel('Sampling year')
ax_time.set_ylabel('Samples (log)')


for ax, label in zip([ax_host, ax_map, ax_time], alphabet):
    ax.text(-0.05, 1.05, label + '.', transform=ax.transAxes, fontsize=14, va='top', ha='right', fontweight='bold')

fig.tight_layout()

figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'sample_origins'), contexts=['paper', 'talk'])

# save data
hostCounts.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig1A_host_data.csv'))
metadata[['run_accession', 'longitude', 'latitude']].to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig1B_loc_data.csv'))
timeCounts.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig1C_year_data.csv'))
_images/02_figures_6_0.png

Figure S1

fig, axes = plt.subplots(1,2, figsize=(10,7), sharey=True, gridspec_kw={'width_ratios': [.3, 1]})

readPlatformSum = metadata.groupby('instrument_platform').agg({'run_accession': 'count'}).reset_index().sort_values(by='instrument_platform')
sns.barplot(
    data = readPlatformSum,
    y = 'instrument_platform',
    x = 'run_accession',
    orient='h',
    ax=axes[0]
)
axes[0].set_xscale('log')

for index, row in readPlatformSum.iterrows():
    if row.run_accession > 20:
        axes[0].text(row.run_accession-10, index, row.run_accession, color='black', ha="right")
    else:
        axes[0].text(row.run_accession, index, row.run_accession, color='black', ha="left")



sns.boxenplot(
    data=metadata.sort_values(by='instrument_platform'),
    x='raw_reads',
    y='instrument_platform',
    ax=axes[1],
    orient='h'
),
axes[1].set_xscale('log')

# labels
axes[0].set_ylabel("Instrument platform")
axes[1].set_ylabel('')

axes[0].set_xlabel("Samples (log-scale)")
axes[1].set_xlabel("Raw reads (log-scale)")

axes[0].text(-0.05, 1, 'a.', transform=axes[0].transAxes, fontsize=14, va='top', ha='right', fontweight='bold')
axes[1].text(-0.01, 1, 'b.', transform=axes[1].transAxes, fontsize=14, va='top', ha='right', fontweight='bold')

fig.tight_layout()
figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'platform_overview'))

# save data
readPlatformSum.to_csv(os.path.join(DATA_PATH, 'figure-data', 'figS1A_platform_count_data.csv'))
metadata[['run_accession', 'instrument_platform', 'raw_reads']].to_csv(os.path.join(DATA_PATH, 'figure-data', 'figS1B_platform_frag_data.csv'))
_images/02_figures_8_0.png

Figure S2

found_stats = query_db(
    "select ResFinder_anno.anno_value as ResFinder_class, count(gene) as Total, sum(IF(gene in (SELECT distinct(refSequence) from AVA_public), 1, 0)) as Found, count(gene) - sum(IF(gene in (SELECT distinct(refSequence) from AVA_public), 1, 0))  as Missing from ResFinder_anno where anno_type='Class' group by anno_value"
)
found_total = query_db(
    "CREATE TEMPORARY TABLE gene_found AS (select gene, IF(gene in (SELECT distinct(refSequence) from AVA_public), 'Found', 'Missing') as found from ResFinder_anno where anno_type='Class'); select found, count(distinct(gene)) as gcount from gene_found group by found"
)

n = found_stats.shape[0]
found_stats.loc[n, ['ResFinder_class', 'Found', 'Missing', 'Total']] = ['All'] + found_total.set_index('found').T.values.tolist()[0] + [sum(found_total.set_index('found').T.values.tolist()[0])]
for c in ['Missing', 'Found', 'Total']:
    found_stats[c] = found_stats[c].astype('int')

found_stats['label'] = found_stats.apply(lambda x: f"{x['ResFinder_class']} ({x['Found']}/{x['Total']})", axis=1, result_type='expand')
fig, ax = plt.subplots(figsize=(6,7))
found_stats.set_index('label')[['Missing', 'Found']].sort_index(ascending=False).coda.closure(100).plot.barh(stacked=True, color=['grey', 'lightgray'], ax=ax)
ax.legend(loc='upper right')
ax.set_ylabel('')
ax.set_xlabel('Percentage')
fig.tight_layout()
figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'ResFinder_found'), contexts=['paper', 'talk'])

# save data
found_stats.to_csv(os.path.join(DATA_PATH, 'figure-data', 'figS2_found-missing_data.csv'))
_images/02_figures_12_0.png

Figure 2

pull = False
fragDataFile = os.path.join(DATA_PATH, 'fragData.h5')

if pull:
    fragData=query_db("select run_accession, category, total_fragments, category_fragments from Diversity")
    fragData.to_hdf(fragDataFile, key='fragData')

else:
    fragData = pd.read_hdf(fragDataFile, key='fragData')
# summary statistics for how many fragments could be mapped
print(
    fragData.drop_duplicates(subset=['run_accession', 'total_fragments'])['total_fragments'].describe(), 
    fragData.drop_duplicates(subset=['run_accession', 'total_fragments'])['total_fragments'].sum(), 
    sep='\n'
)
count    2.132560e+05
mean     7.927448e+06
std      2.150121e+07
min      1.000000e+00
25%      5.570600e+04
50%      7.847480e+05
75%      6.711620e+06
max      9.169014e+08
Name: total_fragments, dtype: float64
1690575810608
# mapping percentages
(fragData.groupby('category').agg({'category_fragments': 'sum', 'total_fragments': 'sum'})  / 1690575810608)*100
category_fragments total_fragments
category
ARG 0.043015 97.130376
Genera 0.191964 99.844671
Phyla 0.191964 99.844671
fig = plt.figure(figsize=(7,7), constrained_layout=True)
spec = gridspec.GridSpec(ncols=1, nrows=2, figure=fig, height_ratios=[.2,  1])

ax_density = fig.add_subplot(spec[0, 0])
ax_scatter = fig.add_subplot(spec[1, 0])

sns.kdeplot(
    data=fragData.drop_duplicates(subset=['run_accession', 'total_fragments']),
    x='total_fragments',
    fill=True,
    ax=ax_density,
    log_scale=True
)
ax_density.set_xlabel("Fragment per metagenome (log)")

sns.scatterplot(
    data=fragData.pivot_table(index='run_accession', columns='category', values='category_fragments').dropna(subset=['Phyla']).fillna(0),
    x='Phyla',
    y='ARG',
    alpha=.15,
    ax=ax_scatter,
    s=4,
    marker='.'
)

ax_scatter.set_xscale('symlog')
ax_scatter.set_yscale('symlog')

ax_scatter.set_xlim(-1, ax_scatter.get_xlim()[1])
ax_scatter.set_ylim(-1, ax_scatter.get_ylim()[1])

ax_scatter.set_xlabel('rRNA fragments')
ax_scatter.set_ylabel('ARG fragments')

for ax, label in zip([ax_density, ax_scatter], alphabet):
    ax.text(-0.0, 1.05, label + '.', transform=ax.transAxes, fontsize=14, va='top', ha='right', fontweight='bold')


figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'overview_fragments_scatter'), contexts=['paper', 'talk'])

# save data
fragData.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig2AB_fragments_data.csv'))
_images/02_figures_17_0.png

Figure S3

fig, axes = plt.subplots(2, 1, figsize=(10,10), sharex=True)

sns.scatterplot(
    data = metadata[['run_accession','raw_reads', 'raw_bases']].merge(fragData.pivot_table(index='run_accession', columns='category', values='category_fragments'), left_on='run_accession', right_index=True),
    x = 'raw_bases',
    y = 'ARG',
    marker = '.',
    ax = axes[0],
    alpha=.15,
    s=4
)

sns.scatterplot(
    data = metadata[['run_accession','raw_reads', 'raw_bases']].merge(fragData.pivot_table(index='run_accession', columns='category', values='category_fragments'), left_on='run_accession', right_index=True),
    x = 'raw_bases',
    y = 'Phyla',
    marker = '.',
    ax = axes[1],
    alpha=.15,
    s=4
)

for ax, label in zip(axes.flatten(), alphabet):
    ax.set_xscale('symlog')
    ax.set_yscale('symlog')

    ax.set_xlim(-1, ax.get_xlim()[1])
    ax.set_ylim(-1, ax.get_ylim()[1])
    
    ax.text(-0.0, 1.05, label + '.', transform=ax.transAxes, fontsize=14, va='top', ha='right', fontweight='bold')
    
axes[0].set_ylabel('ARG fragments')
axes[1].set_ylabel('rRNA fragments')
axes[1].set_xlabel('Raw base count')
    
fig.tight_layout()
figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'overview_bases_scatter'), contexts=['paper', 'talk'])
_images/02_figures_19_0.png

Figure S4

selected_hosts = {
    'Homo sapiens': 'Homo sapiens',
    'marine metagenome': 'marine metagenome',
    'soil metagenome': 'soil metagenome',
    'freshwater metagenome': 'freshwater metagenome',
    'wastewater metagenome': 'wastewater metagenome',
    'Canis lupus familiaris': 'Canis lupus familiaris',
    'pig gut metagenome': 'Sus scrofa',
    'Sus scrofa': 'Sus scrofa',
    'Sus scrofa domesticus': 'Sus scrofa',
    'pig metagenome': 'Sus scrofa',
    'chicken gut metagenome': 'Gallus gallus',
    'Gallus gallus': 'Gallus gallus',
    'cow dung metagenome': 'Bos taurus',
    'Bos taurus': 'Bos taurus',
    'mouse metagenome': 'Mus musculus',
    'mouse gut metagenome': 'Mus musculus',
    'Mus musculus': 'Mus musculus',
    
}

t = fragData.pivot_table(index=['run_accession', 'total_fragments'], columns='category', values='category_fragments').fillna(0).reset_index().merge(metadata, on='run_accession')
t = t.query("Phyla > 0")
t['selected_host'] = t['host'].map(selected_hosts)
fig, axes = plt.subplots(5, 2, figsize=(3*2, 3*5), sharex=True, sharey=True)


colors = list(sns.palettes.color_palette('tab20').as_hex())
colors.pop(14)
colors.pop(14)


for ax, h, c, letter in zip(axes.flatten(), sorted(set(selected_hosts.values())), colors, alphabet):
    
    td = t.loc[t.selected_host == h]
    tn = t.loc[t.selected_host != h]
    
    ax.scatter(
        x = tn['Phyla'],
        y=tn['ARG'],
        color='lightgray',
        alpha = .05,
        marker='.',
        s=.1
    )
    
    ax.scatter(
        x = td['Phyla'],
        y=td['ARG'],
        color=c,
        alpha = .1,
        marker='.',
        s=1
    )
    
    ax.set_xscale('symlog')
    ax.set_yscale('symlog')
    ax.set_xlim(-1, ax.get_xlim()[1])
    ax.set_ylim(-1, ax.get_ylim()[1])
    ax.set_title(h.title())
    



for ax in axes[:, 0].flatten():
    ax.set_ylabel('ARG fragments')

for ax in axes[-1, :].flatten():
    ax.set_xlabel('rRNA fragments')

fig.tight_layout()
figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'overview_fragments_host'), contexts=['paper', 'talk'])

Figure 3

pull = True
classAMRfile = os.path.join(DATA_PATH, 'classAMR.h5')
if pull: 
    classAMRdf = query_db("select run_accession, bacterial_fragment, anno_value as ResFinder_class, sum(fragmentCountAln/(refSequence_length / 1e3)) as fragmentCountAln from AVA_public right join ResFinder_anno on refSequence=gene where anno_type='Class' group by run_accession, bacterial_fragment, anno_value")
    classAMRdf = classAMRdf.merge(metadata, on='run_accession', how='left')
    classAMRdf.replace(['Missing', 'Not available'], np.nan, inplace=True)
    
    classAMRdf.to_hdf(classAMRfile, key='classAMR')
else:
    classAMRdf = pd.read_hdf(classAMRfile, key='classAMR')
/Users/hanmar/.pyenv/versions/3.8.0/lib/python3.8/site-packages/pandas/core/generic.py:2703: PerformanceWarning: 
your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['run_accession', 'ResFinder_class', 'host', 'instrument_platform',
       'country', 'continent'],
      dtype='object')]

  pytables.to_hdf(
ARGabundances = classAMRdf.replace(['Missing', 'Not available'], np.nan).groupby(
    ['run_accession', 'year', 'continent', 'host']
).agg(
    {'bacterial_fragment': 'max', 'fragmentCountAln': 'sum'}
).apply(lambda x: np.log(x['fragmentCountAln'] / (x['bacterial_fragment'] / 1e6)), axis=1).replace([np.inf, -np.inf], np.nan).dropna()
<ipython-input-21-e7deea0d47fb>:5: RuntimeWarning: divide by zero encountered in double_scalars
  ).apply(lambda x: np.log(x['fragmentCountAln'] / (x['bacterial_fragment'] / 1e6)), axis=1).replace([np.inf, -np.inf], np.nan).dropna()
<ipython-input-21-e7deea0d47fb>:5: RuntimeWarning: divide by zero encountered in log
  ).apply(lambda x: np.log(x['fragmentCountAln'] / (x['bacterial_fragment'] / 1e6)), axis=1).replace([np.inf, -np.inf], np.nan).dropna()
fig = plt.figure(figsize=(15,10), constrained_layout=True)
spec = gridspec.GridSpec(ncols=3, nrows=2, figure=fig, height_ratios=[.15, 1], hspace=.02)

ax_bar = fig.add_subplot(spec[0,:])
ax_year = fig.add_subplot(spec[1,0])
ax_continent = fig.add_subplot(spec[1,1])
ax_host = fig.add_subplot(spec[1, 2])

sns.boxplot(data=ARGabundances.dropna(), orient='h', ax=ax_bar)#, scale='count')
ax_bar.set_xlabel('Abundance (log)')
ax_bar.set_yticks(ticks=[])
ax_bar.set_title("Sample-wise ARG abundance")

sns.boxplot(data=ARGabundances.reset_index(level=1), x=0, y='year', ax=ax_year, orient='h', palette="Blues")
sns.boxplot(data=ARGabundances.reset_index(level=2).sort_values(by='continent'), x=0, y='continent', ax=ax_continent, orient='h', color="lightgrey", palette="Blues", order=continent_order)
sns.boxplot(data=ARGabundances.loc[ARGabundances.index.get_level_values('host').isin(host1k),].reset_index(level=3).sort_values(by='host'), x=0, y='host', ax=ax_host, orient='h', palette="Blues", order=host_order)

ax_host.set_yticklabels([host_axlabels[i] for i in host_index])
ax_year.set_yticklabels([t.get_text().replace('.0', '') for t in ax_year.get_yticklabels()])

for ax, t, label in zip([ax_bar, ax_year, ax_continent, ax_host], ['Sample-wise ARG abundance','Sampling host', 'Sampling location', 'Sampling host'], alphabet):
    ax.set_xlabel('Abundance (log)', fontsize=plt.rcParams['axes.labelsize']-.5)
    ax.set_ylabel('')
    ax.set_title(t, fontsize=plt.rcParams['axes.titlesize']+.5)
    ax.text(-0.005, 1.025, label + '.', transform=ax.transAxes, fontsize=14, va='top', ha='right', fontweight='bold')

figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'ARG_abundances'), contexts=['paper', 'talk'])

# save data
ARGabundances.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig3_ARG_abundances_data.csv'))
_images/02_figures_26_0.png

Figure 4

ResFinder_classes = sorted(classAMRdf.ResFinder_class.dropna().unique())
AMR_colors = get_cmap(ResFinder_classes, 'tab20')

continentRel = classAMRdf.groupby(['continent', 'ResFinder_class']).agg({'fragmentCountAln': 'sum'}).reset_index().pivot('continent', 'ResFinder_class', 'fragmentCountAln').fillna(0).coda.closure(1)
yearRel = classAMRdf.groupby(['year', 'ResFinder_class']).agg({'fragmentCountAln': 'sum'}).reset_index().pivot('year', 'ResFinder_class', 'fragmentCountAln').fillna(0).coda.closure(1)
hostRel = classAMRdf.groupby(['host', 'ResFinder_class']).agg({'fragmentCountAln': 'sum'}).reset_index().pivot('host', 'ResFinder_class', 'fragmentCountAln').fillna(0).coda.closure(1)
fig = plt.figure(figsize=(15,10), constrained_layout=True)
spec = gridspec.GridSpec(ncols=3, nrows=2, figure=fig, height_ratios=[1, .15], hspace=.05)

ax_year = fig.add_subplot(spec[0,0])
ax_continent = fig.add_subplot(spec[0,1])
ax_host = fig.add_subplot(spec[0, 2])
ax_leg = fig.add_subplot(spec[1,:])

xlabel = 'Relative proportion'

yearRel.sort_index(ascending=False).plot.barh(stacked=True, color=AMR_colors, ax=ax_year, legend=False)
ax_year.set_xlabel(xlabel)
ax_year.set_ylabel('')
ax_year.set_title("Sampling year")
ax_year.set_xlim(0,1)
ax_year.set_yticklabels([t.get_text().replace('.0', '') for t in ax_year.get_yticklabels()])

continentRel.loc[continent_order].loc[::-1].plot.barh(stacked=True, color=AMR_colors, ax=ax_continent, legend=False, sort_columns=False)
ax_continent.set_xlabel(xlabel)
ax_continent.set_ylabel('')
ax_continent.set_title("Sampling location")
ax_continent.set_xlim(0,1)

hostRel.loc[host_order[::-1]].plot.barh(stacked=True, color=AMR_colors, ax=ax_host, legend=False)
ax_host.set_xlabel(xlabel)
ax_host.set_ylabel('')
ax_host.set_title("Sampling host")
ax_host.set_xlim(0, 1)
ax_host.set_yticklabels([t.get_text().title() for t in ax_host.get_yticklabels()])

for ax, label in zip([ax_year, ax_continent, ax_host], alphabet):
    ax.text(-0.005, 1.025, label + '.', transform=ax.transAxes, fontsize=14, va='top', ha='right', fontweight='bold')

handles = [Patch(facecolor=v, edgecolor=v, label=k) for k, v in AMR_colors.items()]
ax_leg.legend(handles=handles, ncol=6, loc="center")
ax_leg.axis('off')
figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'ARG_distribution'), contexts=['paper', 'talk'])

# save data
yearRel.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig4A_year_relARG_data.csv'))
continentRel.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig4B_contintent_relARG_data.csv'))
hostRel.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig4C_host_relARG_data.csv'))
_images/02_figures_29_0.png

Figure S5

pull = True

diversityFile = os.path.join(DATA_PATH, 'classAMR.h5')
if pull:
    divtest = query_db("select run_accession, kma_version, category, n, Shannon, IF(Simpson<0, NULL, Simpson) as Simpson from Diversity;")
    div = divtest.pivot_table(index='run_accession', columns=['category'], values=['n', 'Shannon', 'Simpson'])
    
    div.to_hdf(diversityFile, key='diversity')
else:
    div = pd.read_hdf(diversityFile, key='diversity')
shannondiv = div['Shannon'].merge(metadata, left_index=True, right_on='run_accession', how='left')
shannondiv['selected_host'] = shannondiv['host'].map(selected_hosts)

ndiv = div['n'].merge(metadata, left_index=True, right_on='run_accession', how='left')
ndiv['selected_host'] = ndiv['host'].map(selected_hosts)

fig = plt.figure(figsize=(5 + (4*2), 7*2))

spec = gridspec.GridSpec(ncols=5, nrows=5*2, figure=fig, wspace=0.1, hspace=0.9, width_ratios=[1, 1, 0.01, 0.5, 0.5])


ax_main_n = fig.add_subplot(spec[:5, :2])
sns.scatterplot(
    data = ndiv,
    x = 'Genera',
    y = 'ARG',
    s = 6,
    alpha = .2,
    ax = ax_main_n
)
ax_main_n.set_xscale('symlog')
ax_main_n.set_yscale('symlog')

ax_main_n.set_xlim(0, ax_main_n.get_xlim()[1])
ax_main_n.set_ylim(0, ax_main_n.get_ylim()[1])

ax_main_n.set_xlabel('')
ax_main_n.set_ylabel('ARG')

specs = [(i, j) for i in range(5) for j in range(3, 5)]
for i, (s, h, c, letter) in enumerate(zip(specs, sorted(set(selected_hosts.values())), colors, alphabet)):
    ax = fig.add_subplot(spec[s[0], s[1]])    
    sns.scatterplot(
        data = ndiv.loc[ndiv.selected_host != h],
        x = 'Genera',
        y = 'ARG',
        s = 1,
        alpha = .1,
        ax = ax,
        color='lightgray'
    )
      
    sns.scatterplot(
        data = ndiv.loc[ndiv.selected_host == h],
        x = 'Genera',
        y = 'ARG',
        s = 4,
        alpha = .2,
        ax = ax,
        color=c
    )
    
    ax.set_xscale('symlog')
    ax.set_yscale('symlog')
    ax.set_xlim(0, ax.get_xlim()[1])
    ax.set_ylim(0, ax.get_ylim()[1])
    
    ax.set_title(h.title())
    
    if  i < 8: 
        ax.set_xticklabels([])
    ax.set_xlabel('')

        
    if (i+1) % 2 == 0:
        ax.set_yticklabels([])
    ax.set_ylabel('')


ax_main_shannon = fig.add_subplot(spec[-5:, :2])
sns.scatterplot(
    data = shannondiv,
    x = 'Genera',
    y = 'ARG',
    s = 6,
    alpha = .2,
    ax = ax_main_shannon
)


ax_main_shannon.set_xlabel('Genera')
ax_main_shannon.set_ylabel('ARG')

specs = [(i, j) for i in range(5, 10) for j in range(3, 5)]
sub_axes = []
for i, (s, h, c, letter) in enumerate(zip(specs, sorted(set(selected_hosts.values())), colors, alphabet)):
    ax = fig.add_subplot(spec[s[0], s[1]])    
    sns.scatterplot(
        data = shannondiv.loc[shannondiv.selected_host != h],
        x = 'Genera',
        y = 'ARG',
        s = 1,
        alpha = .1,
        ax = ax,
        color='lightgray'
    )
      
    sns.scatterplot(
        data = shannondiv.loc[shannondiv.selected_host == h],
        x = 'Genera',
        y = 'ARG',
        s = 4,
        alpha = .2,
        ax = ax,
        color=c
    )
    
    ax.set_title(h.title())
    
    if  i < 8: 
        ax.set_xticklabels([])
    ax.set_xlabel('')

        
    if (i+1) % 2 == 0:
        ax.set_yticklabels([])
    ax.set_ylabel('')

ax_main_n.text(-0.005, 1.025, 'a.', transform=ax_main_n.transAxes, fontsize=14, va='top', ha='right', fontweight='bold')
ax_main_n.set_title('Observed richness', fontsize=14)
ax_main_shannon.text(-0.005, 1.025, 'b.', transform=ax_main_shannon.transAxes, fontsize=14, va='top', ha='right', fontweight='bold')
ax_main_shannon.set_title('Shannon diversity', fontsize=14)

figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'overview_richness_shannon_host'), contexts=['paper', 'talk'])

# save data
ndiv.to_csv(os.path.join(DATA_PATH, 'figure-data', 'figS4A_observed_richness_data.csv'))
shannondiv.to_csv(os.path.join(DATA_PATH, 'figure-data', 'figS4B_shannon_div_data.csv'))
_images/02_figures_32_0.png

Figure 5

simpsondiv = div['Simpson'].merge(metadata, left_index=True, right_on='run_accession', how='left')
simpsondiv['selected_host'] = simpsondiv['host'].map(selected_hosts)

fig = plt.figure(figsize=(5 + (4*2), 7))

spec = gridspec.GridSpec(ncols=5, nrows=5, figure=fig, wspace=0.1, hspace=0.4, width_ratios=[1, 1, 0.01, 0.5, 0.5])

ax_main = fig.add_subplot(spec[:, :2])
sns.scatterplot(
    data = simpsondiv,
    x = 'Genera',
    y = 'ARG',
    s = 6,
    alpha = .2,
    ax = ax_main
)


ax_main.set_xlabel('Genera Gini-Simpson')
ax_main.set_ylabel('ARG Gini-Simpson')

specs = [(i, j) for i in range(5) for j in range(3, 5)]
sub_axes = []
for i, (s, h, c, letter) in enumerate(zip(specs, sorted(set(selected_hosts.values())), colors, alphabet)):
    ax = fig.add_subplot(spec[s[0], s[1]])    
    sns.scatterplot(
        data = simpsondiv.loc[simpsondiv.selected_host != h],
        x = 'Genera',
        y = 'ARG',
        s = 1,
        alpha = .1,
        ax = ax,
        color='lightgray'
    )
      
    sns.scatterplot(
        data = simpsondiv.loc[simpsondiv.selected_host == h],
        x = 'Genera',
        y = 'ARG',
        s = 4,
        alpha = .2,
        ax = ax,
        color=c
    )
    
    ax.set_title(h.title())
    
    if  i < 8: 
        ax.set_xticklabels([])
    ax.set_xlabel('')

        
    if (i+1) % 2 == 0:
        ax.set_yticklabels([])
    ax.set_ylabel('')
figure_saver(fig=fig, prefix=os.path.join(FIG_PATH, 'overview_simpson_host'), contexts=['paper', 'talk'])

# save data
simpsondiv.to_csv(os.path.join(DATA_PATH, 'figure-data', 'fig5_simpson_div_data.csv'))
_images/02_figures_34_0.png