Skip to content
Snippets Groups Projects
Commit 9595ac25 authored by schereme's avatar schereme
Browse files

fix, tests: Add raise exceptions for model and FASTA file download issues -...

parent 4acf69a9
Branches
Tags feature_importance
2 merge requests!21Merge develop branch into main,!20fix, tests: Add raise exceptions for model and FASTA file download issues -...
Pipeline #6692 passed with stage
in 3 minutes and 59 seconds
......@@ -12,3 +12,41 @@ class NoGoAnnotationFoundError(Exception):
f"flag <is_go> to False) for the score calculation."
)
super(Exception, self).__init__(self.message)
class CouldNotDownloadAlphaFoldModelError(Exception):
"""Exception for cases when AlphaFold cannot be downloaded from API."""
def __init__(self, uniprot_id, err_msg):
self.uniprot_id = uniprot_id
self.err_msg = err_msg
self.message = f"Could NOT download AlphaFold model file for UniProt Id: {self.uniprot_id}! Received the following server message: {self.err_msg}."
super(Exception, self).__init__(self.message)
class CouldNotDownloadFASTAFileError(Exception):
"""Exception for cases when FASTA file cannot be downloaded from API."""
def __init__(self, uniprot_id, err_msg):
self.uniprot_id = uniprot_id
self.err_msg = err_msg
self.message = f"Could NOT download FASTA file for UniProt Id: {self.uniprot_id}! Received the following server message: {self.err_msg}."
super(Exception, self).__init__(self.message)
class CouldNotParseFASTAFileError(Exception):
"""Exception for cases when FASTA file cannot be parsed."""
def __init__(self, fasta_file_path):
self.fasta_file_path = fasta_file_path
self.message = f"Could NOT parse FASTA file: {self.fasta_file_path}!"
super(Exception, self).__init__(self.message)
class InvalidUniProtIdProvidedError(Exception):
"""Exception for cases when an invalid UniProt Id has been provided."""
def __init__(self, uniprot_id):
self.uniprot_id = uniprot_id
self.message = f"Invalid UniProt Id provided: {uniprot_id}! Please make sure that the provided UniProt Id is not empty | None."
super(Exception, self).__init__(self.message)
......@@ -4,6 +4,8 @@ import os
import prody as pr
import requests
from ..exceptions import CouldNotDownloadAlphaFoldModelError, CouldNotDownloadFASTAFileError
def get_dict():
aadict = {
......@@ -251,31 +253,32 @@ def calculate_sequence_structure_af_pipe_one_af(af_models_path, uniprot_id):
return af_fea
def get_af_model_from_api(uniprot_id: str, path_file) -> str:
def send_request_and_write_content_to_file(url: str, file_path: str):
with open(file_path, "wb") as f:
response = requests.get(url)
response.raise_for_status()
f.write(response.content)
try:
url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.pdb"
full_path = f"{path_file}AF-{uniprot_id}-F1-model_v4.pdb"
with open(full_path, "wb") as f:
f.write(requests.get(url).content)
def get_af_model_from_api(uniprot_id: str, path_file) -> str:
try:
url: str = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.pdb"
file_path: str = f"{path_file}AF-{uniprot_id}-F1-model_v4.pdb"
send_request_and_write_content_to_file(url, file_path)
except requests.exceptions.RequestException as err:
raise SystemExit(err)
return full_path
raise CouldNotDownloadAlphaFoldModelError(uniprot_id, err)
return file_path
def get_fasta_from_api(uniprot_id: str, path_file) -> str:
try:
url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
full_path = f"{path_file}{uniprot_id}.fasta"
with open(full_path, "wb") as f:
f.write(requests.get(url).content)
file_path = f"{path_file}{uniprot_id}.fasta"
send_request_and_write_content_to_file(url, file_path)
except requests.exceptions.RequestException as err:
raise SystemExit(err)
return full_path
raise CouldNotDownloadFASTAFileError(uniprot_id, err)
return file_path
if __name__ == "__main__":
......
......@@ -4,6 +4,7 @@ from typing import Optional
from Bio import SeqIO
from ..exceptions import CouldNotParseFASTAFileError
from ..features.calculation_disorder import calculate_disorder_one
from ..features.goterms import calculate_3dir_go_one
from ..features.sequence_complexity import calculate_k40_k60_complexity_one
......@@ -41,12 +42,8 @@ def calculate_pipeline_automated_one(
path_uni = os.path.join(path_output, name + "/")
if not os.path.isdir(path_uni):
os.mkdir(path_uni)
path_af = get_af_model_from_api(name, path_uni)
get_af_model_from_api(name, path_uni)
input_fasta = get_fasta_from_api(name, path_uni)
if path_af is None:
return None
if input_fasta is None:
return None
fasta_sequences = SeqIO.parse(open(input_fasta), "fasta")
records = []
for fasta in fasta_sequences:
......@@ -65,7 +62,7 @@ def calculate_pipeline_automated_one(
return allfea
else:
return None
raise CouldNotParseFASTAFileError(input_fasta)
if __name__ == "__main__":
......
......@@ -6,6 +6,7 @@ import catboost
import numpy as np
import pandas as pd
from ..exceptions import InvalidUniProtIdProvidedError
from ..files import get_go_dir_path, get_model_dir_path
from .calculation_pipeline import calculate_pipeline_automated_one, calculate_pipeline_one
......@@ -111,8 +112,6 @@ def _calculate_inference_probability(
)
if is_automated:
rr = calculate_pipeline_automated_one(output_path, uniprot_id, go_flag, get_go_dir_path())
if rr is None:
return -1
else:
rr = calculate_pipeline_one(output_path, uniprot_id, fasta_file_dir, go_flag, get_go_dir_path())
......@@ -129,6 +128,8 @@ def inference_model_with_go_one(
fasta_dir, path_af, uniprot_id, is_automated: bool = True, fold: int = 10
) -> [np.float64, pd.DataFrame]:
"""Calculates a PICNIC score for a single protein sequence using a model trained with the GO annotation feature."""
if not uniprot_id:
raise InvalidUniProtIdProvidedError(uniprot_id)
picnic_score, feat_importance = _calculate_inference_probability(
model_file_name="modelpipe_depth6class1_id_2_llps_withgonocc_retrained_newgo18.sav",
feature_keys_file_name="keys_llps_withgonocc_retrained_newgo_18.txt",
......@@ -146,6 +147,8 @@ def inference_model_with_go_one(
def inference_model_without_go_one(fasta_dir, path_af, uniprot_id, is_automated: bool = True, fold: int = 10):
"""Calculates a PICNIC score for a single protein sequence using a model trained without the GO annotation feature."""
if not uniprot_id:
raise InvalidUniProtIdProvidedError(uniprot_id)
picnic_score, feat_importance = _calculate_inference_probability(
model_file_name="modelpipe_depth7class1_id_92_llps_withoutgo_24-02.sav",
feature_keys_file_name="keys_llps_withoutgocattrue_92.txt",
......
......@@ -3,7 +3,7 @@ from unittest.mock import call
import pandas as pd
import pytest
from src.exceptions import NoGoAnnotationFoundError
from src.exceptions import InvalidUniProtIdProvidedError, NoGoAnnotationFoundError
from src.prediction.inference_model import inference_model_with_go_one, inference_model_without_go_one
from tests.unit import EXPECTED_FEATURE_IMPORTANCE_PROTEIN_Q99720
......@@ -156,3 +156,24 @@ def test_inference_model_without_go_one_should_succeed(
# and the correct result should be returned
assert actual_score == expected_score_result
pd.testing.assert_frame_equal(actual_feat_importance, expected_feat_importance, check_names=True)
@pytest.mark.parametrize(
"invalid_uniprot_id",
["", None],
)
def test_inference_model_without_go_one_should_raise_exception(tmpdir, invalid_uniprot_id):
"""Test 'inference_model_without_go_one' method should succeed raise exception."""
# ... given
# ... a tmpdir fixture
# ... and a UniProt ID
# when ... we call `inference_model_without_go_one()`
output_dir = tmpdir.mkdir("picnic_output")
with pytest.raises(InvalidUniProtIdProvidedError) as err:
inference_model_without_go_one(None, output_dir, invalid_uniprot_id)
# and ... the correct error message
assert err.value.message == (
f"Invalid UniProt Id provided: {invalid_uniprot_id}! Please make sure that the provided UniProt Id is not empty | None."
)
import pytest
from src.exceptions import CouldNotDownloadAlphaFoldModelError, CouldNotDownloadFASTAFileError
from src.features.sequence_structure_AF2 import get_af_model_from_api, get_fasta_from_api
@pytest.mark.parametrize(
"uniprot_id",
[
"Q99720",
],
)
def test_get_af_model_from_api_should_succeed(tmpdir, uniprot_id):
"""Test 'get_af_model_from_api' method should succeed and return expected results."""
# ... given
# ... a tmpdir fixture
# ... and a UniProt ID
# when ... we call `get_af_model_from_api()`
output_dir = tmpdir.mkdir("picnic_output")
actual_result = get_af_model_from_api(uniprot_id=uniprot_id, path_file=output_dir)
# then the correct result should be returned
assert actual_result == f"{output_dir}AF-{uniprot_id}-F1-model_v4.pdb"
@pytest.mark.parametrize(
"invalid_uniprot_id",
["", None],
)
def test_get_af_model_from_api_should_raise_exception(tmpdir, invalid_uniprot_id):
"""Test 'get_af_model_from_api' method raise exception."""
# ... given
# ... a tmpdir fixture
# ... and a UniProt ID
# when ... we call `get_af_model_from_api()`
output_dir = tmpdir.mkdir("picnic_output")
with pytest.raises(CouldNotDownloadAlphaFoldModelError) as err:
get_af_model_from_api(uniprot_id=invalid_uniprot_id, path_file=output_dir)
# and ... the correct error message
assert err.value.message == (
f"Could NOT download AlphaFold model file for UniProt Id: {invalid_uniprot_id}! Received the following server message: 404 Client Error: Not Found for url: https://alphafold.ebi.ac.uk/files/AF-{invalid_uniprot_id}-F1-model_v4.pdb."
)
@pytest.mark.parametrize(
"uniprot_id",
[
"Q99720",
],
)
def test_get_fasta_from_api_should_succeed(tmpdir, uniprot_id):
"""Test 'get_fasta_from_api' method should succeed and return expected results."""
# ... given
# ... a tmpdir fixture
# ... and a UniProt ID
# when ... we call `get_af_model_from_api()`
output_dir = tmpdir.mkdir("picnic_output")
actual_result = get_fasta_from_api(uniprot_id=uniprot_id, path_file=output_dir)
# then the correct result should be returned
assert actual_result == f"{output_dir}{uniprot_id}.fasta"
@pytest.mark.parametrize(
"invalid_uniprot_id, expected_error_msg",
[
("", "404 Client Error: Not Found"),
(None, "400 Client Error: Bad Request"),
],
)
def test_get_fasta_from_api_should_raise_exception(tmpdir, invalid_uniprot_id, expected_error_msg):
"""Test 'get_fasta_from_api' method raise exception."""
# ... given
# ... a tmpdir fixture
# ... and a UniProt ID
# when ... we call `get_af_model_from_api()`
output_dir = tmpdir.mkdir("picnic_output")
with pytest.raises(CouldNotDownloadFASTAFileError) as err:
get_fasta_from_api(uniprot_id=invalid_uniprot_id, path_file=output_dir)
# and ... the correct error message
assert err.value.message == (
f"Could NOT download FASTA file for UniProt Id: {invalid_uniprot_id}! Received the following server message: {expected_error_msg} for url: https://rest.uniprot.org/uniprotkb/{invalid_uniprot_id}.fasta."
)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment