Coverage for oc_ds_converter / datasource / orcid_index.py: 81%
149 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-12 21:23 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-12 21:23 +0000
1# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it>
2#
3# SPDX-License-Identifier: ISC
5from __future__ import annotations
7import json
8from collections import defaultdict
9from concurrent.futures import ProcessPoolExecutor, as_completed
10from csv import DictReader
11from multiprocessing import get_context
12from os import cpu_count, sep, walk
13from os.path import exists
14from typing import Protocol, cast
16import fakeredis
18from oc_ds_converter.datasource.redis import RedisDataSource
19from oc_ds_converter.lib.console import create_progress
20from oc_ds_converter.oc_idmanager import DOIManager
class OrcidIndexInterface(Protocol):
    """Structural type for a DOI -> ORCID lookup backend.

    Any object providing these two methods (for example ``OrcidIndexRedis``)
    can be used where an ORCID index is expected.
    """

    def get_value(self, id_string: str, /) -> set[str] | None:
        """Return the values stored for *id_string*, or ``None`` when there are none."""
        ...

    def get_values_batch(self, ids: list[str], /) -> dict[str, set[str]]:
        """Return a mapping from identifiers in *ids* to their stored values."""
        ...
class OrcidIndexRedis:
    """Redis-backed DOI -> ORCID index.

    Every DOI is stored as a Redis key whose value is a set of ORCID strings.
    With ``testing=True`` an in-memory ``fakeredis`` instance is used instead
    of the ``DOI-ORCID-INDEX`` data source.
    """

    def __init__(self, testing: bool = False) -> None:
        if not testing:
            self._redis = RedisDataSource("DOI-ORCID-INDEX")
            self._r = self._redis._r
            return
        self._r = fakeredis.FakeStrictRedis(decode_responses=True)

    def get_value(self, doi: str) -> set[str] | None:
        """Return the ORCIDs stored under *doi*, or ``None`` if the set is empty."""
        members = cast(set[str], self._r.smembers(doi))
        return members if members else None

    def get_values_batch(self, dois: list[str]) -> dict[str, set[str]]:
        """Fetch the ORCID sets of many DOIs through a single pipeline.

        DOIs without stored ORCIDs are left out of the returned mapping.
        """
        if not dois:
            return {}
        pipeline = self._r.pipeline()
        for key in dois:
            pipeline.smembers(key)
        found: dict[str, set[str]] = {}
        for key, members in zip(dois, pipeline.execute()):
            if members:
                found[key] = cast(set[str], members)
        return found

    def add_values_batch(self, data: dict[str, set[str]]) -> None:
        """SADD each non-empty ORCID set in *data* to its DOI key via one pipeline."""
        pipeline = self._r.pipeline()
        non_empty = ((doi, values) for doi, values in data.items() if values)
        for doi, values in non_empty:
            pipeline.sadd(doi, *values)
        pipeline.execute()

    def has_data(self) -> bool:
        """Return True when the underlying Redis database holds any key."""
        size = cast(int, self._r.dbsize())
        return size > 0

    def clear(self) -> None:
        """Delete every key in the underlying Redis database."""
        self._r.flushdb()
def _process_csv_file(csv_path: str) -> dict[str, set[str]]:
    """Read one DOI-ORCID CSV file and group its ``value`` entries by DOI.

    This function is submitted to a process pool by
    ``load_orcid_index_to_redis``, so each file is parsed in a separate
    process.

    Args:
        csv_path: Path to a CSV file with a header row containing at least
            the ``id`` and ``value`` columns.

    Returns:
        A plain dict mapping each DOI, as returned by
        ``DOIManager.normalise(..., include_prefix=True)``, to the set of
        ``value`` entries seen for it. Rows whose ``id`` does not normalise
        (normalise returns a falsy value) are skipped.
    """
    doi_manager = DOIManager()
    result: dict[str, set[str]] = defaultdict(set)

    # The csv module requires files opened with newline=''; otherwise
    # newlines embedded inside quoted fields are not interpreted correctly.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        for row in DictReader(f):
            doi = doi_manager.normalise(row['id'], include_prefix=True)
            if doi:
                result[doi].add(row['value'])

    # Hand back a plain dict rather than the defaultdict.
    return dict(result)
def load_orcid_index_to_redis(
    orcid_index_dir: str,
    orcid_index_redis: OrcidIndexRedis,
    batch_size: int = 1000000,
    max_workers: int | None = None,
) -> None:
    """Load every ``*.csv`` file under *orcid_index_dir* into *orcid_index_redis*.

    Files are parsed in parallel by ``_process_csv_file`` inside a
    ``ProcessPoolExecutor`` using the ``forkserver`` start method. Per-file
    results are merged in memory. They are written with ``add_values_batch``
    whenever the number of accumulated values reaches *batch_size*, and once
    more at the end for any remainder.

    Args:
        orcid_index_dir: Directory searched recursively for ``.csv`` files.
            If it does not exist, or contains no CSV files, nothing happens.
        orcid_index_redis: Destination index.
        batch_size: Accumulated value count that triggers a flush to Redis.
        max_workers: Pool size; when ``None`` it defaults to
            ``min(cpu_count() or 4, number_of_csv_files)``.
    """
    if not exists(orcid_index_dir):
        return

    csv_paths = [
        dirpath + sep + filename
        for dirpath, _, filenames in walk(orcid_index_dir)
        for filename in filenames
        if filename.endswith('.csv')
    ]
    if not csv_paths:
        return

    workers = (
        max_workers
        if max_workers is not None
        else min(cpu_count() or 4, len(csv_paths))
    )

    pending: dict[str, set[str]] = {}
    pending_values = 0

    with create_progress() as progress:
        task_id = progress.add_task(
            "[green]Loading DOI-ORCID index files", total=len(csv_paths)
        )
        pool_context = get_context("forkserver")
        with ProcessPoolExecutor(max_workers=workers, mp_context=pool_context) as pool:
            submitted = {
                pool.submit(_process_csv_file, path): path for path in csv_paths
            }
            for done in as_completed(submitted):
                # Fold this file's DOI -> ORCID sets into the pending batch.
                for doi, orcids in done.result().items():
                    existing = pending.get(doi)
                    if existing is None:
                        pending[doi] = orcids
                    else:
                        existing.update(orcids)
                    pending_values += len(orcids)

                if pending_values >= batch_size:
                    orcid_index_redis.add_values_batch(pending)
                    pending = {}
                    pending_values = 0

                progress.update(task_id, advance=1)

    if pending:
        orcid_index_redis.add_values_batch(pending)
class PublishersRedis:
    """Redis-backed publisher index, addressable by member id or DOI prefix.

    Two kinds of keys are written:

    * ``member:<member_id>`` -> JSON ``{"name": ..., "prefixes": [...]}``
    * ``prefix:<doi_prefix>`` -> ``<member_id>``

    With ``testing=True`` an in-memory ``fakeredis`` instance is used instead
    of the ``PUBLISHERS-INDEX`` data source.
    """

    MEMBER_PREFIX = "member:"
    DOI_PREFIX_KEY = "prefix:"

    def __init__(self, testing: bool = False) -> None:
        if testing:
            self._r = fakeredis.FakeStrictRedis(decode_responses=True)
        else:
            self._redis = RedisDataSource("PUBLISHERS-INDEX")
            self._r = self._redis._r

    def get_by_member(self, member_id: str) -> dict[str, str | set[str]] | None:
        """Return ``{"name": ..., "prefixes": {...}}`` for *member_id*, or ``None``.

        The stored JSON list of prefixes is converted back into a set.
        """
        key = f"{self.MEMBER_PREFIX}{member_id}"
        data = self._r.get(key)
        if data:
            result = json.loads(str(data))
            result["prefixes"] = set(result["prefixes"])
            return result
        return None

    def get_by_prefix(self, prefix: str) -> dict[str, str | set[str]] | None:
        """Resolve a DOI *prefix* to its member id and return that member's record."""
        key = f"{self.DOI_PREFIX_KEY}{prefix}"
        member_id = self._r.get(key)
        if member_id:
            return self.get_by_member(str(member_id))
        return None

    def set_publisher(self, member_id: str, name: str, prefixes: set[str]) -> None:
        """Store one publisher and map each of its *prefixes* to *member_id*.

        Delegates to :meth:`set_publishers_batch`. The member record and all
        prefix mappings are then queued on a single pipeline and sent
        together, using exactly the same key layout and JSON shape as batch
        loading. Previously one SET was issued per key.
        """
        self.set_publishers_batch({member_id: {"name": name, "prefixes": prefixes}})

    def set_publishers_batch(self, publishers: dict[str, dict[str, str | set[str]]]) -> None:
        """Store many publishers through one pipeline.

        Args:
            publishers: Mapping ``member_id -> {"name": str, "prefixes": set[str]}``.
                Each member record is written with SET, replacing any
                previous record for that member.
        """
        pipe = self._r.pipeline()
        for member_id, data in publishers.items():
            member_key = f"{self.MEMBER_PREFIX}{member_id}"
            prefixes_list = list(data["prefixes"])
            pipe.set(member_key, json.dumps({"name": data["name"], "prefixes": prefixes_list}))
            for prefix in prefixes_list:
                prefix_key = f"{self.DOI_PREFIX_KEY}{prefix}"
                pipe.set(prefix_key, member_id)
        pipe.execute()

    def has_data(self) -> bool:
        """Return True when the underlying Redis database holds any key."""
        return cast(int, self._r.dbsize()) > 0

    def clear(self) -> None:
        """Delete every key in the underlying Redis database."""
        self._r.flushdb()
def load_publishers_to_redis(
    publishers_filepath: str,
    publishers_redis: PublishersRedis,
    batch_size: int = 5000,
) -> None:
    """Load a publishers CSV (``id``, ``name``, ``prefix`` columns) into Redis.

    Rows are first grouped by publisher id, so every publisher is written
    exactly once with its complete prefix set. Writes go through
    ``set_publishers_batch`` in chunks. A chunk is flushed once it holds at
    least *batch_size* prefixes, and any remainder is flushed at the end.
    If a publisher id appears with several names, the name on its first row
    is kept.

    Args:
        publishers_filepath: Path to the CSV file; if it does not exist,
            nothing happens.
        publishers_redis: Destination index.
        batch_size: Approximate number of prefixes per write batch.
    """
    if not exists(publishers_filepath):
        return

    # Aggregate the whole file before writing. set_publishers_batch stores
    # each member record with SET (overwrite). Flushing mid-file would let a
    # later batch replace a publisher's record with only the prefixes read
    # after the flush.
    names: dict[str, str] = {}
    prefixes_by_id: dict[str, set[str]] = {}
    # newline='' as required by the csv module for correct quoted-field parsing.
    with open(publishers_filepath, 'r', encoding='utf-8', newline='') as f:
        for row in DictReader(f):
            pub_id = row['id']
            if pub_id not in names:
                names[pub_id] = row['name']
                prefixes_by_id[pub_id] = set()
            prefixes_by_id[pub_id].add(row['prefix'])

    batch: dict[str, dict[str, str | set[str]]] = {}
    count = 0
    for pub_id, name in names.items():
        prefixes = prefixes_by_id[pub_id]
        batch[pub_id] = {'name': name, 'prefixes': prefixes}
        count += len(prefixes)
        if count >= batch_size:
            publishers_redis.set_publishers_batch(batch)
            batch = {}
            count = 0

    if batch:
        publishers_redis.set_publishers_batch(batch)