import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import warnings
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
import umap
from copy import deepcopy
from matplotlib.colors import CenteredNorm
import matplotlib.patches as patches
from scipy.stats import chi2
from sklearn.preprocessing import StandardScaler
from antipasti.utils.biology_utils import remove_nanobodies, extract_mean_region_lengths
def get_maps_of_interest(preprocessed_data, learnt_filter, affinity_thr=-8):
    r"""Post-processes both raw data and results to obtain maps of interest.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class.
    learnt_filter: numpy.ndarray
        Filters that express the learnt features during training.
    affinity_thr: float
        Affinity value separating antibodies considered to have high affinity (``train_y < affinity_thr``)
        from those considered to have low affinity (``train_y > affinity_thr``). Maps whose value equals
        the threshold exactly belong to neither group.

    Returns
    -------
    mean_learnt: numpy.ndarray
        A resized version of ``learnt_filter`` (sign flipped) to match the shape of the input normal mode correlation maps.
    mean_image: numpy.ndarray
        The mean of all the input normal mode correlation maps.
    mean_diff_image: numpy.ndarray
        Map resulting from the subtraction of the mean of the high affinity correlation maps and the mean of the low affinity correlation maps.

    Raises
    ------
    ValueError
        If either the high or the low affinity group is empty, since its mean map is undefined.

    """
    train_x = np.asarray(preprocessed_data.train_x)
    # Flatten so that both (N,) and (N, 1) label arrays yield a per-map boolean mask
    train_y = np.asarray(preprocessed_data.train_y).reshape(-1)
    input_shape = train_x.shape[-1]

    high_aff = train_x[train_y < affinity_thr]
    low_aff = train_x[train_y > affinity_thr]
    if len(high_aff) == 0 or len(low_aff) == 0:
        raise ValueError(
            f'affinity_thr={affinity_thr} leaves {len(high_aff)} high-affinity and '
            f'{len(low_aff)} low-affinity maps; both groups must be non-empty.'
        )

    # Obtaining the maps
    mean_learnt = cv2.resize(-learnt_filter, dsize=(input_shape, input_shape))
    mean_image = np.mean(train_x, axis=0).reshape(input_shape, input_shape)
    mean_diff_image = np.mean(high_aff, axis=0) - np.mean(low_aff, axis=0)

    return mean_learnt, mean_image, mean_diff_image
def get_output_representations(preprocessed_data, model):
    r"""Returns maps that reveal the important residue interactions for the binding affinity. We call them 'output layer representations'.

    For every training map, the intermediate feature maps returned by ``model`` (second output) are
    multiplied filter-wise by the matching dense-layer weights, resized to ``input_shape x input_shape``,
    summed over filters and rescaled by ``size_le**2 / input_shape**2``. The same quantity computed
    for an all-zero input is subtracted as an offset.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class. ``train_x`` is read but not modified.
    model: antipasti.model.model.ANTIPASTI
        The model class, i.e., ``ANTIPASTI``.

    Returns
    -------
    each_img_enl: numpy.ndarray
        Array of shape ``(n_train, input_shape**2)``; row ``j`` is the flattened representation of
        training map ``j``.

    """
    input_shape = preprocessed_data.test_x.shape[-1]
    n_filters = model.n_filters
    n_heavy = len(preprocessed_data.max_res_list_h)

    # NOTE(review): the weight expression was garbled in this file; this assumes the dense output
    # layer is ``model.fc1`` with a single output and ``n_filters * size_le**2`` inputs — confirm.
    dense_weights = model.fc1.weight.detach().cpu().numpy()
    size_le = int(np.sqrt(dense_weights.shape[-1] / n_filters))
    filter_weights = dense_weights.reshape(n_filters, size_le, size_le)
    # Compensates for the area change when enlarging size_le x size_le maps to input_shape x input_shape
    scale = size_le**2 / input_shape**2

    # Border effects: zero the outer rows/columns and the heavy/light junction on a private copy,
    # so the caller's training data is not modified as a side effect
    train_x = np.array(preprocessed_data.train_x, copy=True)
    train_x[:, :, 0] = 0
    train_x[:, 0, :] = 0
    train_x[:, :, n_heavy - 1:n_heavy + 1] = 0
    train_x[:, n_heavy - 1:n_heavy + 1, :] = 0
    train_x[:, :, -1] = 0
    train_x[:, -1, :] = 0

    def enlarged_representation(image):
        """Rescaled, flattened sum over filters of resized (feature map x dense weight) products."""
        batch = torch.from_numpy(np.asarray(image).reshape(1, 1, input_shape, input_shape).astype(np.float32))
        with torch.no_grad():
            inter = model(batch)[1].detach().cpu().numpy()
        total = np.zeros(input_shape**2)
        for i in range(n_filters):
            product = np.multiply(inter[0, i], filter_weights[i])
            total += cv2.resize(product, dsize=(input_shape, input_shape)).reshape(input_shape**2)
        return scale * total

    # The zero-input offset is scaled exactly like the maps so the subtraction is consistent
    offset = enlarged_representation(np.zeros((input_shape, input_shape)))

    each_img_enl = np.zeros((train_x.shape[0], input_shape**2))
    for j in range(train_x.shape[0]):
        each_img_enl[j] = enlarged_representation(train_x[j]) - offset
    return each_img_enl
def get_test_contribution(preprocessed_data, model):
    r"""Returns a map that reveals the important residue interactions for the binding affinity.

    The intermediate feature maps that ``model`` returns (second output) for ``test_x`` are multiplied
    filter-wise by the matching dense-layer weights, resized to ``input_shape x input_shape`` and
    subtracted from a zero map. Unlike ``get_output_representations``, no rescaling factor is applied
    and no zero-input offset is removed.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class. ``test_x`` must be reshapeable to a single
        ``input_shape x input_shape`` map.
    model: antipasti.model.model.ANTIPASTI
        The model class, i.e., ``ANTIPASTI``.

    Returns
    -------
    each_img_enl: numpy.ndarray
        Array of shape ``(input_shape, input_shape)``.

    """
    test_x = preprocessed_data.test_x
    input_shape = test_x.shape[-1]
    n_filters = model.n_filters

    # NOTE(review): the weight expression was garbled in this file; this assumes the dense output
    # layer is ``model.fc1`` with a single output and ``n_filters * size_le**2`` inputs — confirm.
    dense_weights = model.fc1.weight.detach().cpu().numpy()
    size_le = int(np.sqrt(dense_weights.shape[-1] / n_filters))
    filter_weights = dense_weights.reshape(n_filters, size_le, size_le)

    batch = torch.from_numpy(np.asarray(test_x).reshape(1, 1, input_shape, input_shape).astype(np.float32))
    with torch.no_grad():
        inter_filter_item = model(batch)[1].detach().cpu().numpy()

    each_img_enl = np.zeros((input_shape, input_shape))
    for i in range(n_filters):
        product = np.multiply(inter_filter_item[0, i], filter_weights[i])
        each_img_enl -= cv2.resize(product, dsize=(input_shape, input_shape))
    return each_img_enl
def plot_map_with_regions(preprocessed_data, map, title='Normal mode correlation map', interactive=False):
    r"""Maps the residues to the antibody regions and plots the normal mode correlation map.

    Region boundaries are located in ``max_res_list_h`` / ``max_res_list_l`` via anchor residues
    (heavy: 1, 26, 33, 52, 57, 95, 103, 113; light: 1, 24, 35, 50, 57, 89, 98, 107). CDRs are drawn
    in deep pink and framework regions in orange along both axes, with chain bands outside them.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class.
    map: numpy.ndarray
        A normal mode correlation map.
    title: str
        The image title.
    interactive: bool
        Set to ``True`` when running a script or ``pytest``.

    """
    # Font sizes
    title_size = 42
    font_size = 32

    # Defining the region boundaries
    mrlh = preprocessed_data.max_res_list_h
    mrll = preprocessed_data.max_res_list_l
    subgroup_boundaries_h = [mrlh.index(r) for r in ('1', '26', '33', '52', '57', '95', '103')] + [mrlh.index('113') + 1]
    subgroup_boundaries_l = [mrll.index(r) for r in ('1', '24', '35', '50', '57', '89', '98')] + [mrll.index('107') + 1]
    labels_h = ['FR-H1', 'CDR-H1', 'FR-H2', 'CDR-H2', 'FR-H3', 'CDR-H3', 'FR-H4']
    labels_l = ['FR-L1', 'CDR-L1', 'FR-L2', 'CDR-L2', 'FR-L3', 'CDR-L3', 'FR-L4']
    heavy_end = subgroup_boundaries_h[-1]
    light_end = subgroup_boundaries_l[-1]
    # Light-chain indices are shifted by the heavy-chain length since the chains are concatenated
    subgroup_boundaries = subgroup_boundaries_h[:-1] + [x + heavy_end for x in subgroup_boundaries_l]
    labels = labels_h + labels_l

    plt.figure(figsize=(20, 20))
    ax = plt.gca()

    # Plotting the normal mode correlation map
    plt.imshow(map, origin='lower', cmap='seismic', norm=CenteredNorm())
    cbar = plt.colorbar(fraction=0.045)
    # Set the font size of the colorbar
    cbar.ax.tick_params(labelsize=font_size)

    for i in range(len(subgroup_boundaries) - 1):
        start_index = subgroup_boundaries[i]
        end_index = subgroup_boundaries[i + 1]
        label_position = (start_index + end_index) / 2 - 0.5
        # Choosing the colours
        c = 'deeppink' if labels[i].startswith('CDR') else 'orange'

        # Adding rectangles for the regions (below the x-axis and left of the y-axis)
        ax.add_patch(plt.Rectangle((start_index - 0.5, -6.5), end_index - start_index, 6, edgecolor='black', facecolor=c, alpha=0.7))
        ax.add_patch(plt.Rectangle((-12.5, start_index - 0.5), 12, end_index - start_index, edgecolor='black', facecolor=c, alpha=0.7))

        # Add labels for the regions on the y-axis
        plt.text(-6, label_position - 0.25, labels[i], ha='center', va='center', color='black', size=10)
        # Add labels for the regions on the x-axis
        plt.text(label_position, -3.7, labels[i], ha='center', va='center', color='black', size=9)

    # Adding rectangles for the chains
    ax.add_patch(plt.Rectangle((-0.5, -10.5), heavy_end, 4, edgecolor='black', facecolor='white'))
    ax.add_patch(plt.Rectangle((heavy_end - 0.5, -10.5), light_end, 4, edgecolor='black', facecolor='white'))
    ax.add_patch(plt.Rectangle((-16.5, -0.5), 4, heavy_end, edgecolor='black', facecolor='white'))
    ax.add_patch(plt.Rectangle((-16.5, heavy_end - 0.5), 4, light_end, edgecolor='black', facecolor='white'))

    # Adding labels for the chains on the y-axis
    plt.text(-14, heavy_end / 2, 'Heavy chain', ha='center', va='center', color='black', rotation='vertical', size=14)
    plt.text(-14, heavy_end + light_end / 2, 'Light chain', ha='center', va='center', color='black', rotation='vertical', size=14)
    # Adding labels for the chains on the x-axis
    plt.text(heavy_end / 2, -9, 'Heavy chain', ha='center', va='center', color='black', size=12)
    plt.text(heavy_end + light_end / 2, -9, 'Light chain', ha='center', va='center', color='black', size=12)

    # Adjusting the axis limits and labels
    plt.xlim(-16.5, 290.5)
    plt.ylim(-10.5, 290.5)

    # Adding labels and title
    plt.xlabel('Residues', size=font_size)
    plt.ylabel('Residues', size=font_size)
    plt.title(title, size=title_size)

    # NOTE(review): display branch reconstructed from a truncated original — confirm.
    if interactive:
        # Non-blocking so scripts and pytest runs are not halted by the figure window
        plt.show(block=False)
    else:
        plt.show()
def _umap_class_dictionary(scheme, external_cdict):
    """Return ``(cdict, title)``: the label-to-index dictionary and plot title for a categorical scheme.

    Built-in schemes get a fixed dictionary and a human-readable title; any other scheme falls back
    to ``external_cdict`` (possibly ``None``) and keeps ``scheme`` as the title.
    """
    if scheme == 'light_subclass':
        return {'IGKV1': 0,
                'IGKV2': 1,
                'IGKV3': 2,
                'IGKV4': 3,
                'IGKV5': 4,
                'IGKV6': 5,
                'IGKV7': 6,
                'IGKV8': 7,
                'IGKV9': 8,
                'IGKV10': 9,
                'IGKV14': 10,
                'IGLV1': 11,
                'IGLV2': 12,
                'IGLV6': 13,
                'Other': 14}, 'Light chain V gene family'
    if scheme == 'heavy_subclass':
        return {'IGHV1': 0,
                'IGHV2': 1,
                'IGHV3': 2,
                'IGHV4': 3,
                'IGHV5': 4,
                'IGHV6': 5,
                'IGHV7': 6,
                'Other': 7}, 'Heavy chain V gene family'
    if scheme in ('heavy_species', 'light_species'):
        return {'homo sapiens': 0,
                'mus musculus': 1,
                'Other': 2}, 'Antibody species'
    if scheme == 'light_ctype':
        return {'Kappa': 0,
                'Lambda': 1,
                'unknown': 2,
                'NA': 3,
                'Other': 4}, 'Type of light chain'
    if scheme == 'antigen_type':
        return {'protein': 0,
                'peptide': 1,
                'Hapten': 2,
                'protein | protein': 3,
                'carbohydrate': 4,
                'Other': 5}, 'Type of antigen'
    return external_cdict, scheme


def _parse_umap_value(item):
    """Return ``item`` as a float, or ``None`` if it cannot be used as a numerical UMAP colour.

    Numeric scalars are accepted as they are; strings must parse as a finite float, so entries such
    as ``'-8.5'`` or ``'1e-9'`` are kept while ``'nan'``, free text or ``None`` are rejected.
    """
    try:
        value = float(item)
    except (TypeError, ValueError):
        return None
    if isinstance(item, str) and not np.isfinite(value):
        return None
    return value


def compute_umap(preprocessed_data, model, scheme='heavy_species', categorical=True, include_ellipses=False, numerical_values=None, external_cdict=None, interactive=False, exclude_nanobodies=False):
    r"""Performs UMAP dimensionality reduction calculations.

    Output layer representations are standardised and embedded in 2D with UMAP, then coloured either
    by a SAbDab column (``categorical=True``) or by ``numerical_values`` and plotted with ``plot_umap``.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class.
    model: antipasti.model.model.ANTIPASTI
        The model class, i.e., ``ANTIPASTI``.
    scheme: str
        Category of the labels or values appearing in the UMAP representation.
    categorical: bool
        ``True`` if ``scheme`` is categorical.
    include_ellipses: bool
        ``True`` if ellipses comprising 85% of the points of a given class are included.
    numerical_values: list
        A list of values or entries should be provided if data external to SAbDab is used. Required
        when ``categorical`` is ``False``; entries that are not numbers are dropped together with
        their points.
    external_cdict: dictionary
        Option to provide an external dictionary of the UMAP labels (used for schemes without a
        built-in dictionary; should contain an ``'Other'`` key).
    interactive: bool
        Set to ``True`` when running a script or ``pytest``.
    exclude_nanobodies: bool
        Set to ``True`` to exclude nanobodies from the UMAP plot.

    Returns
    -------
    colours: list
        The colour index or numerical value of each plotted point.
    pdb_codes: list or numpy.ndarray
        The PDB codes of the plotted points.
    embedding: numpy.ndarray
        The 2D UMAP embedding of the plotted points.

    Raises
    ------
    ValueError
        If ``categorical`` is ``True`` and no dictionary is available for ``scheme``, or if
        ``categorical`` is ``False`` and ``numerical_values`` is ``None``.

    """
    train_x = preprocessed_data.train_x
    pdb_codes = preprocessed_data.labels
    reducer = umap.UMAP(random_state=32, min_dist=0.1, n_neighbors=90)  # Paired-HL

    # Obtaining the labels: value of the last SAbDab entry of each training PDB code
    labels = []
    db = pd.read_csv(preprocessed_data.data_path + 'sabdab_summary_all.tsv', sep='\t')
    if scheme in db.columns:
        last_rows = db.drop_duplicates(subset='pdb', keep='last')
        scheme_by_pdb = dict(zip(last_rows['pdb'], last_rows[scheme]))
        labels = [str(scheme_by_pdb[pdb_codes[j]]) for j in range(train_x.shape[0])]

    # UMAP fitting on the standardised output layer representations
    each_img_enl = get_output_representations(preprocessed_data, model)
    scaled_each_img = StandardScaler().fit_transform(each_img_enl)
    embedding = reducer.fit_transform(scaled_each_img)

    if exclude_nanobodies:
        pdb_codes, _, embedding, labels, numerical_values = remove_nanobodies(pdb_codes, train_x, embedding, labels, numerical_values)

    colours = []
    if categorical:
        cdict, scheme = _umap_class_dictionary(scheme, external_cdict)
        if cdict is None:
            raise ValueError(f'No label dictionary for scheme {scheme!r}; pass external_cdict.')
        for label in labels:
            # Labels without an entry in the dictionary are pooled into 'Other'
            colours.append(cdict[label] if label in cdict else cdict['Other'])
    else:
        cdict = None
        if numerical_values is None:
            raise ValueError('numerical_values must be provided when categorical is False.')
        parsed = [_parse_umap_value(item) for item in numerical_values]
        keep = np.array([value is not None for value in parsed], dtype=bool)
        colours = [value for value in parsed if value is not None]
        # Points without a usable value are removed from the embedding and the codes alike
        if not keep.all():
            embedding = embedding[keep]
            pdb_codes = np.asarray(pdb_codes)[keep]

    plot_umap(embedding=embedding, colours=colours, scheme=scheme, pdb_codes=pdb_codes, categorical=categorical, include_ellipses=include_ellipses, cdict=cdict, interactive=interactive)

    return colours, pdb_codes, embedding
def _class_ellipse(points, conf_level, scale_factor, colour):
    """Return an Ellipse fitted to the ``conf_level`` fraction of ``points`` nearest their centroid, or None if fewer than two remain."""
    n_inside = int(np.ceil(len(points) * conf_level))
    if n_inside < 2:
        # A covariance matrix needs at least two points
        return None
    centroid = np.mean(points, axis=0)
    nearest = np.argsort(np.sum(np.square(points - centroid), axis=1))[:n_inside]
    inside_points = points[nearest]

    # Mean and covariance of the retained points only
    center = np.mean(inside_points, axis=0)
    eigenvalues, eigenvectors = np.linalg.eigh(np.cov(inside_points.T))
    angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))
    # Clip tiny negative eigenvalues caused by round-off before taking square roots
    radius = np.sqrt(np.clip(eigenvalues, 0, None)) * scale_factor
    return patches.Ellipse(xy=center, width=2 * radius[0], height=2 * radius[1],
                           angle=angle, fill=False, linewidth=3, alpha=0.7, color=colour)


def plot_umap(embedding, colours, scheme, pdb_codes, categorical=True, include_ellipses=False, cdict=None, interactive=False):
    r"""Plots UMAP maps.

    Parameters
    ----------
    embedding: numpy.ndarray
        The output layer representations after dimensionality reduction.
    colours: list
        The data points labels or values.
    scheme: str
        Category of the labels or values appearing in the UMAP representation.
    pdb_codes: list
        The PDB codes of the antibodies.
    categorical: bool
        ``True`` if ``scheme`` is categorical.
    include_ellipses: bool
        ``True`` to include ellipses comprising 85% of the points of a given class. Classes named
        ``'unknown'`` or ``'Other'`` in ``cdict`` get no ellipse.
    cdict: dictionary
        External dictionary of the UMAP labels (label name to colour index).
    interactive: bool
        Set to ``True`` when running a script or ``pytest``.

    """
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot()

    cmap = plt.get_cmap('tab10' if categorical else 'Purples')
    unique_colours = sorted(set(colours))
    norm = plt.Normalize(np.min(colours), np.max(colours))

    # Map each colour index back to its class name (first key wins if several share a value)
    value_to_name = {}
    if categorical and cdict is not None:
        for name, value in cdict.items():
            value_to_name.setdefault(value, name)

    ax.scatter(embedding[:, 0], embedding[:, 1], s=50, c=colours, cmap=cmap, norm=norm)
    for code, point in zip(pdb_codes, embedding):
        ax.annotate(code, (point[0], point[1]), size=8)

    if include_ellipses:
        conf_level = 0.85
        # Inverse of the chi-squared CDF with two degrees of freedom
        scale_factor = np.sqrt(chi2.ppf(conf_level, df=2))
        colour_array = np.asarray(colours)
        for value in unique_colours:
            if value_to_name.get(value) in ('unknown', 'Other'):
                continue
            ellipse = _class_ellipse(embedding[colour_array == value], conf_level, scale_factor, cmap(norm(value)))
            if ellipse is not None:
                ax.add_patch(ellipse)

    # Legend entries are built from the same values as their patches, so names and colours match
    if categorical:
        legend_values = unique_colours
        legend_names = [value_to_name.get(value, str(value)) for value in legend_values]
    else:
        legend_values = unique_colours[:10]
        legend_names = [str(value) for value in legend_values]
    legend_patches = [patches.Patch(color=cmap(norm(value))) for value in legend_values]
    ax.legend(legend_patches, legend_names, loc='best')

    ax.set_title(scheme, size=18)
    ax.set_xlabel('UMAP 1', size=16)
    ax.set_ylabel('UMAP 2', size=16)

    # NOTE(review): display branch reconstructed from a truncated original — confirm.
    if interactive:
        # Non-blocking so scripts and pytest runs are not halted by the figure window
        plt.show(block=False)
    else:
        plt.show()
def plot_region_importance(importance_factor, importance_factor_ob, antigen_type, mode='region', interactive=False):
    r"""Plots ranking of important regions.

    Regions are drawn as horizontal stacked bars sorted by decreasing ``importance_factor``; the
    solid part is the off-block share and the translucent part the remainder. Heavy-chain region
    names are coloured green.

    Parameters
    ----------
    importance_factor: list
        Measure of importance (0-100) for each of the 14 antibody regions.
    importance_factor_ob: list
        Measure of importance (0-100) for each antibody region attributable to off-block correlations. This can be inter-region or inter-chain depending on the selected ``mode``.
    antigen_type: int
        Plot corresponding to antigens of a given type. These can be proteins (0), haptens (1), peptides (2) or carbohydrates (3).
    mode: str
        ``region`` to explicitly show which correlations are inter/intra-region (likewise for ``chain``).
    interactive: bool
        Set to ``True`` when running a script or ``pytest``.

    """
    labels = ['FR-H1', 'CDR-H1', 'FR-H2', 'CDR-H2', 'FR-H3', 'CDR-H3', 'FR-H4', 'FR-L1', 'CDR-L1', 'FR-L2', 'CDR-L2', 'FR-L3', 'CDR-L3', 'FR-L4']
    n_heavy_regions = 7
    # Antigen type -> index in the 'tab10' colour map
    mapping_dict = {0: 0, 1: 2, 2: 1, 3: 5}
    colour = plt.get_cmap('tab10')(mapping_dict[antigen_type])

    importance = np.asarray(importance_factor)
    importance_ob = np.asarray(importance_factor_ob)
    sorted_indices = np.argsort(importance)[::-1]  # Reverse order

    # Create bars for each class
    fig, ax = plt.subplots()
    y_pos = np.arange(len(labels))
    ax.barh(y_pos, importance_ob[sorted_indices], align='center', color=colour, label=f'Inter-{mode}')
    ax.barh(y_pos, (importance - importance_ob)[sorted_indices], align='center', alpha=0.6, left=importance_ob[sorted_indices], color=colour, label=f'Intra-{mode}')
    ax.set_xlabel('Importance (%)')
    # Fix one tick per bar before labelling so every name sits on its own bar
    ax.set_yticks(y_pos)
    ax.set_yticklabels([labels[i] for i in sorted_indices])
    # The bar labels above are only displayed through a legend
    ax.legend()
    plt.tick_params(axis='y', which='both', left=False, right=False)
    plt.tick_params(axis='x', which='major', bottom=True, right=True, size=3.5)

    for tick_label, region_index in zip(ax.get_yticklabels(), sorted_indices):
        tick_label.set_color('green' if region_index < n_heavy_regions else '#333333')

    # NOTE(review): display branch reconstructed from a truncated original — confirm.
    if interactive:
        # Non-blocking so scripts and pytest runs are not halted by the figure window
        plt.show(block=False)
    else:
        plt.show()
def add_region_based_on_range(list_residues):
    r"""Appends, in brackets, the antibody region of each residue according to its position.

    The region is determined by the residue's index in ``list_residues`` (heavy chain first, then
    light chain), not by parsing the label: positions 0-25 are FR-H1, 26-37 CDR-H1, ..., 279-291
    FR-L4. Residues at position 292 or beyond fall outside every region and are omitted.

    Parameters
    ----------
    list_residues: list of str
        Residue labels in Chothia numbering, heavy chain followed by light chain.

    Returns
    -------
    list of str
        Entries of the form ``'<residue> (<region>)'``.

    """
    # (region name, first position, position one past the end)
    region_table = (
        ('FR-H1', 0, 26), ('CDR-H1', 26, 38), ('FR-H2', 38, 57), ('CDR-H2', 57, 67),
        ('FR-H3', 67, 116), ('CDR-H3', 116, 142), ('FR-H4', 142, 153),
        ('FR-L1', 153, 176), ('CDR-L1', 176, 195), ('FR-L2', 195, 210), ('CDR-L2', 210, 225),
        ('FR-L3', 225, 265), ('CDR-L3', 265, 279), ('FR-L4', 279, 292),
    )
    annotated = []
    for position, residue in enumerate(list_residues):
        for region_name, first, stop in region_table:
            if first <= position < stop:
                annotated.append(residue + ' (' + region_name + ')')
                break
    return annotated
def plot_residue_importance(preprocessed_data, importance_factor, antigen_type, interactive=False):
    r"""Plots ranking of important residues.

    The 30 most important residues are drawn as horizontal bars in decreasing order of importance,
    each labelled with its residue and region; heavy-chain residues are coloured green.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class.
    importance_factor: list
        Measure of importance (0-100) for each antibody residue, indexed like the concatenation of
        ``max_res_list_h`` and ``max_res_list_l``.
    antigen_type: int
        Plot corresponding to antigens of a given type. These can be proteins (0), haptens (1), peptides (2) or carbohydrates (3).
    interactive: bool
        Set to ``True`` when running a script or ``pytest``.

    """
    n_top = 30
    res_labels = add_region_based_on_range(preprocessed_data.max_res_list_h + preprocessed_data.max_res_list_l)
    n_heavy = len(preprocessed_data.max_res_list_h)
    # Antigen type -> index in the 'tab10' colour map
    mapping_dict = {0: 0, 1: 2, 2: 1, 3: 5}
    colour = plt.get_cmap('tab10')(mapping_dict[antigen_type])

    importance = np.asarray(importance_factor)
    # Show top 30, most important first
    n_shown = min(n_top, len(res_labels), len(importance))
    top_indices = np.argsort(importance)[::-1][:n_shown]

    fig, ax = plt.subplots()
    y_pos = np.arange(n_shown)
    ax.barh(y_pos, importance[top_indices], align='center', alpha=0.9, color=colour)
    ax.set_xlabel('Importance (%)')
    # Fix one tick per bar before labelling so every name sits on its own bar
    ax.set_yticks(y_pos)
    ax.set_yticklabels([res_labels[i] for i in top_indices], fontsize=9.5)
    plt.tick_params(axis='y', which='both', left=False, right=False)
    plt.tick_params(axis='x', which='major', bottom=True, right=True, size=3.5)

    for tick_label, residue_index in zip(ax.get_yticklabels(), top_indices):
        tick_label.set_color('green' if residue_index < n_heavy else '#333333')

    # NOTE(review): display branch reconstructed from a truncated original — confirm.
    if interactive:
        # Non-blocking so scripts and pytest runs are not halted by the figure window
        plt.show(block=False)
    else:
        plt.show()
def get_colours_ag_type(preprocessed_data):
    r"""Returns a different colour according to the antigen type.

    The antigen type of each PDB code in ``preprocessed_data.labels`` is read from the first matching
    row of ``sabdab_summary_all.tsv`` and mapped to an index: protein 0, Hapten 1, peptide 2,
    carbohydrate 3, nucleic-acid 4, 'protein | protein' 5, anything else (including missing) 6.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class.

    Returns
    -------
    colours: list of int
        One colour index per entry of ``preprocessed_data.labels``.

    Raises
    ------
    KeyError
        If a PDB code is absent from the SAbDab summary file.

    """
    cluster_according_to = 'antigen_type'
    db = pd.read_csv(preprocessed_data.data_path + 'sabdab_summary_all.tsv', sep='\t')
    # One lookup table instead of filtering the whole dataframe for every label
    first_rows = db.drop_duplicates(subset='pdb', keep='first')
    antigen_types = dict(zip(first_rows['pdb'], first_rows[cluster_according_to]))

    cdict = {'protein': 0,
             'Hapten': 1,
             'peptide': 2,
             'carbohydrate': 3,
             'nucleic-acid': 4,
             'protein | protein': 5,
             'Other': 6}

    colours = []
    for pdb_code in preprocessed_data.labels:
        if pdb_code not in antigen_types:
            raise KeyError(f'PDB code {pdb_code!r} not found in sabdab_summary_all.tsv.')
        # str() turns missing entries (NaN) into 'nan', which falls into 'Other'
        cluster = str(antigen_types[pdb_code])
        colours.append(cdict[cluster] if cluster in cdict else cdict['Other'])
    return colours
def compute_region_importance(preprocessed_data, model, type_of_antigen, nanobodies, mode='region', interactive=False):
    r"""Computes the importance factors (0-100) of all the Fv antibody regions. Returns the importance for each region.

    For each of the 14 regions, the rows of the output layer representations belonging to that region
    are removed from the summed map and the resulting MSE against ``train_y`` is compared with the
    MSE of the full sum. Deviations are normalised by mean region lengths and scaled so the most
    important region (per residue) scores 100 times its relative length. The removed rows are further
    split into intra/off-block parts (same region vs. other regions, or same chain vs. other chain).

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class.
    model: antipasti.model.model.ANTIPASTI
        The model class, i.e., ``ANTIPASTI``.
    type_of_antigen: int
        Choose between: proteins (0), haptens (1), peptides (2) or carbohydrates (3).
    nanobodies: list
        PDB codes of nanobodies in the dataset; they are excluded.
    mode: str
        ``region`` to explicitly calculate which correlations are inter/intra-region (likewise for ``chain``).
    interactive: bool
        Set to ``True`` when running a script or ``pytest``.

    Returns
    -------
    tot: numpy.ndarray
        Importance of each region.
    ob: numpy.ndarray
        Part of ``tot`` attributable to off-block (inter-region or inter-chain) correlations.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'region'`` nor ``'chain'``, or if no training map matches the antigen type.

    """
    if mode not in ('region', 'chain'):
        raise ValueError(f"mode must be 'region' or 'chain', got {mode!r}.")

    colours = get_colours_ag_type(preprocessed_data)
    # 'protein | protein' complexes (colour 5) are pooled with single proteins (colour 0)
    colours = [0 if c == 5 else c for c in colours]
    each_img_enl = get_output_representations(preprocessed_data, model)
    input_shape = preprocessed_data.test_x.shape[-1]
    labels = preprocessed_data.labels
    n_heavy = len(preprocessed_data.max_res_list_h)
    nanobody_set = set(nanobodies)

    selected = [i for i in range(each_img_enl.shape[0])
                if colours[i] == type_of_antigen and labels[i] not in nanobody_set]
    if not selected:
        raise ValueError(f'No training maps for antigen type {type_of_antigen} after excluding nanobodies.')
    maps = each_img_enl[selected].reshape(-1, input_shape, input_shape)
    train_y_ = np.asarray(preprocessed_data.train_y, dtype=float).reshape(-1)[selected]
    totals = maps.sum(axis=(1, 2))

    def mse(predicted):
        return np.mean((predicted - train_y_) ** 2)

    # Residue index ranges of FR-H1 ... FR-L4 (heavy chain first)
    new_coord = [range(0, 26), range(26, 38), range(38, 57), range(57, 67), range(67, 116), range(116, 142), range(142, 153),
                 range(153, 176), range(176, 195), range(195, 210), range(210, 225), range(225, 265), range(265, 279), range(279, 292)]
    last_column = new_coord[-1][-1] + 1

    all_mse_without_region = []
    all_mse_without_region_intra = []
    all_mse_without_region_ob = []
    for j, region in enumerate(new_coord):
        start, stop = region[0], region[-1] + 1
        band = maps[:, start:stop, :]  # rows of region j
        all_mse_without_region.append(mse(totals - band.sum(axis=(1, 2))))
        if mode == 'region':
            removed_intra = band[:, :, start:stop].sum(axis=(1, 2))
            removed_ob = band[:, :, :start].sum(axis=(1, 2)) + band[:, :, stop:last_column].sum(axis=(1, 2))
        else:
            removed_heavy = band[:, :, :n_heavy].sum(axis=(1, 2))
            removed_light = band[:, :, n_heavy:].sum(axis=(1, 2))
            # The first 7 regions belong to the heavy chain
            removed_intra, removed_ob = (removed_heavy, removed_light) if j < 7 else (removed_light, removed_heavy)
        all_mse_without_region_intra.append(mse(totals - removed_intra))
        all_mse_without_region_ob.append(mse(totals - removed_ob))

    total_mse = mse(totals)
    all_mse_without_region = np.array(all_mse_without_region)
    all_mse_without_region_intra = np.array(all_mse_without_region_intra)
    all_mse_without_region_ob = np.array(all_mse_without_region_ob)

    # Mean lengths (in residues) of each region, used to normalise per residue
    region_mean_lengths = np.array([24, 7.1, 19, 6, 41, 11.3, 10, 21.7, 12.7, 14.9, 7, 32.9, 9.2, 9.1])
    deltas = np.abs(all_mse_without_region - total_mse)
    idx_best_normalised_mean_length = np.argmax(deltas / region_mean_lengths)
    tot = 100 * region_mean_lengths[idx_best_normalised_mean_length] * deltas / deltas[idx_best_normalised_mean_length] / region_mean_lengths
    intra_deltas = np.abs(all_mse_without_region_intra - total_mse)
    ob_deltas = np.abs(all_mse_without_region_ob - total_mse)
    ob = tot - tot * intra_deltas / (intra_deltas + ob_deltas)

    plot_region_importance(tot, ob, type_of_antigen, mode, interactive=interactive)

    return tot, ob
def compute_residue_importance(preprocessed_data, model, type_of_antigen, nanobodies, interactive=False):
    r"""Computes the importance factors (0-100) of all the amino acids of the antibody variable region.

    For each residue, its row of the output layer representation is removed from the summed map and
    the resulting MSE against ``train_y`` is compared with the MSE of the full sum. The absolute
    deviations are scaled so that the largest one equals 100.

    Parameters
    ----------
    preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
        The ``Preprocessing`` class.
    model: antipasti.model.model.ANTIPASTI
        The model class, i.e., ``ANTIPASTI``.
    type_of_antigen: int
        Choose between: proteins (0), haptens (1), peptides (2) or carbohydrates (3).
    nanobodies: list
        PDB codes of nanobodies in the dataset; they are excluded.
    interactive: bool
        Set to ``True`` when running a script or ``pytest``.

    Returns
    -------
    importance: numpy.ndarray
        Importance (0-100) of each residue; all zeros if no residue changes the MSE.

    Raises
    ------
    ValueError
        If no training map matches the antigen type.

    """
    colours = get_colours_ag_type(preprocessed_data)
    # 'protein | protein' complexes (colour 5) are pooled with single proteins (colour 0)
    colours = [0 if c == 5 else c for c in colours]
    each_img_enl = get_output_representations(preprocessed_data, model)
    train_x = preprocessed_data.train_x
    input_shape = preprocessed_data.test_x.shape[-1]
    n_residues = train_x.shape[-1]
    nanobody_set = set(nanobodies)

    selected = [i for i in range(train_x.shape[0])
                if colours[i] == type_of_antigen and preprocessed_data.labels[i] not in nanobody_set]
    if not selected:
        raise ValueError(f'No training maps for antigen type {type_of_antigen} after excluding nanobodies.')
    maps = each_img_enl[selected].reshape(-1, input_shape, input_shape)
    train_y_ = np.asarray(preprocessed_data.train_y, dtype=float).reshape(-1)[selected]

    totals = maps.sum(axis=(1, 2))
    total_mse = np.mean((totals - train_y_) ** 2)
    # Row sums: contribution of each residue (row) to each map's total
    row_sums = maps.sum(axis=2)[:, :n_residues]
    mse_without_residue = np.mean((totals[:, None] - row_sums - train_y_[:, None]) ** 2, axis=0)

    deltas = np.abs(mse_without_residue - total_mse)
    largest = deltas.max()
    # Normalise by the largest deviation so the scores stay within the documented 0-100 range
    importance = 100 * deltas / largest if largest > 0 else np.zeros_like(deltas)

    plot_residue_importance(preprocessed_data, importance, type_of_antigen, interactive=interactive)

    return importance