diff --git a/annif_client.py b/annif_client.py index c610441..f07b9b3 100755 --- a/annif_client.py +++ b/annif_client.py @@ -2,6 +2,8 @@ """Module for accessing Annif REST API""" import requests +import importlib + # Default API base URL API_BASE = 'https://api.annif.org/v1/' @@ -12,24 +14,31 @@ class AnnifClient: def __init__(self, api_base=API_BASE): self.api_base = api_base + version = importlib.metadata.version("annif-client") + self._headers = { + "User-Agent": f"Annif-client/{version}", + } @property def api_info(self): """Get basic information of the API endpoint""" - req = requests.get(self.api_base) + req = requests.get(self.api_base, headers=self._headers) req.raise_for_status() return req.json() @property def projects(self): """Get a list of projects available on the API endpoint""" - req = requests.get(self.api_base + 'projects') + req = requests.get(self.api_base + 'projects', headers=self._headers) req.raise_for_status() return req.json()['projects'] def get_project(self, project_id): """Get a single project by project ID""" - req = requests.get(self.api_base + 'projects/{}'.format(project_id)) + req = requests.get( + self.api_base + 'projects/{}'.format(project_id), + headers=self._headers + ) if req.status_code == 404: raise ValueError(req.json()['detail']) req.raise_for_status() @@ -51,7 +60,7 @@ def suggest(self, project_id, text, limit=None, threshold=None): payload['threshold'] = threshold url = self.api_base + 'projects/{}/suggest'.format(project_id) - req = requests.post(url, data=payload) + req = requests.post(url, data=payload, headers=self._headers) if req.status_code == 404: raise ValueError(req.json()['detail']) req.raise_for_status() @@ -61,7 +70,7 @@ def learn(self, project_id, documents): """Further train an existing project on a text with given subjects.""" url = self.api_base + 'projects/{}/learn'.format(project_id) - req = requests.post(url, json=documents) + req = requests.post(url, json=documents, headers=self._headers) if req.status_code == 404: raise ValueError(req.json()['detail']) req.raise_for_status() diff --git a/tests/test_annif_client.py b/tests/test_annif_client.py index a9c830c..da80010 100644 --- a/tests/test_annif_client.py +++ b/tests/test_annif_client.py @@ -4,6 +4,9 @@ import os.path import pytest import responses +import requests +import importlib +import unittest @pytest.fixture(scope='module') @@ -42,3 +45,13 @@ def test_projects(client): body=open(datafile).read()) result = client.projects assert len(result) == 2 + + +def test_headers(client): + with unittest.mock.patch("requests.get"): + client.api_info + + version = importlib.metadata.version("annif-client") + assert requests.get.call_args.kwargs["headers"] == { + "User-Agent": f"Annif-client/{version}" + }