Coverage for oc_ds_converter/datasource/orcid_index.py: 80%
148 statements
coverage.py v7.13.4, created at 2026-03-25 18:06 +0000
# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it>
#
# SPDX-License-Identifier: ISC

from __future__ import annotations

import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from csv import DictReader
from os import cpu_count, sep, walk
from os.path import exists
from typing import Protocol, cast

import fakeredis

from oc_ds_converter.datasource.redis import RedisDataSource
from oc_ds_converter.lib.console import create_progress
from oc_ds_converter.oc_idmanager import DOIManager


class OrcidIndexInterface(Protocol):
    def get_value(self, id_string: str, /) -> set[str] | None: ...
    def get_values_batch(self, ids: list[str], /) -> dict[str, set[str]]: ...
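

# Illustrative sketch (not part of the original module): because
# OrcidIndexInterface is a typing.Protocol, any object that exposes get_value
# and get_values_batch with matching signatures satisfies it structurally.
# The helper below is hypothetical; it only shows that callers can be typed
# against the protocol rather than a concrete backend, so a fakeredis-backed
# OrcidIndexRedis(testing=True) and the production instance are interchangeable.
def _lookup_orcids(index: OrcidIndexInterface, dois: list[str]) -> dict[str, set[str]]:
    # Delegate to the batch lookup defined by the protocol.
    return index.get_values_batch(dois)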


class OrcidIndexRedis:
    def __init__(self, testing: bool = False) -> None:
        if testing:
            self._r = fakeredis.FakeStrictRedis(decode_responses=True)
        else:
            self._redis = RedisDataSource("DOI-ORCID-INDEX")
            self._r = self._redis._r

    def get_value(self, doi: str) -> set[str] | None:
        result = cast(set[str], self._r.smembers(doi))
        if result:
            return result
        return None

    def get_values_batch(self, dois: list[str]) -> dict[str, set[str]]:
        if not dois:
            return {}
        pipe = self._r.pipeline()
        for doi in dois:
            pipe.smembers(doi)
        results = pipe.execute()
        return {
            doi: cast(set[str], members)
            for doi, members in zip(dois, results)
            if members
        }

    def add_values_batch(self, data: dict[str, set[str]]) -> None:
        pipe = self._r.pipeline()
        for doi, values in data.items():
            if values:
                pipe.sadd(doi, *values)
        pipe.execute()

    def has_data(self) -> bool:
        return cast(int, self._r.dbsize()) > 0

    def clear(self) -> None:
        self._r.flushdb()
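

# Usage sketch (hypothetical, not part of the original module): exercises the
# DOI -> ORCID index against the in-memory fakeredis backend enabled by
# testing=True. The DOI keys and ORCID values below are made up; the
# "doi:"-prefixed key form mirrors what DOIManager.normalise(..., include_prefix=True)
# is expected to produce.
def _example_orcid_index_roundtrip() -> None:
    index = OrcidIndexRedis(testing=True)
    index.add_values_batch({
        "doi:10.1000/example.1": {"0000-0002-1825-0097"},
        "doi:10.1000/example.2": {"0000-0001-5109-3700", "0000-0002-1694-233X"},
    })
    assert index.has_data()
    # A single lookup returns the stored set, or None for an unknown DOI.
    assert index.get_value("doi:10.1000/example.1") == {"0000-0002-1825-0097"}
    assert index.get_value("doi:10.1000/missing") is None
    # The batch lookup pipelines the SMEMBERS calls and drops empty results.
    found = index.get_values_batch(["doi:10.1000/example.1", "doi:10.1000/missing"])
    assert set(found) == {"doi:10.1000/example.1"}
    index.clear()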


def _process_csv_file(csv_path: str) -> dict[str, set[str]]:
    """Process a single CSV file and return DOI -> ORCID mappings.

    This function runs in a separate process for parallelization.
    """
    doi_manager = DOIManager()
    result: dict[str, set[str]] = defaultdict(set)

    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = DictReader(f)
        for row in reader:
            raw_doi = row['id']
            doi = doi_manager.normalise(raw_doi, include_prefix=True)
            if doi:
                result[doi].add(row['value'])

    return dict(result)
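

# Data-format sketch (an assumption drawn from the columns read above, not
# from the original module): each index CSV is expected to expose at least an
# 'id' column holding a DOI and a 'value' column holding the associated ORCID
# entry. The rows and the file name below are invented; the exact keys in the
# returned mapping depend on DOIManager.normalise().
def _example_process_csv(tmp_dir: str) -> dict[str, set[str]]:
    import csv
    import os

    csv_path = os.path.join(tmp_dir, "orcid_index_sample.csv")
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["id", "value"])
        writer.writeheader()
        writer.writerow({"id": "10.1000/EXAMPLE.1",
                         "value": "Carberry, Josiah [0000-0002-1825-0097]"})
        writer.writerow({"id": "10.1000/example.1",
                         "value": "Doe, Jane [0000-0001-5109-3700]"})
    # If normalisation lowercases the DOI, both rows collapse onto one key
    # whose set contains both ORCID entries.
    return _process_csv_file(csv_path)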


def load_orcid_index_to_redis(
    orcid_index_dir: str,
    orcid_index_redis: OrcidIndexRedis,
    batch_size: int = 1000000,
    max_workers: int | None = None,
) -> None:
    if not exists(orcid_index_dir):
        return

    files_to_process: list[str] = []
    for cur_dir, _, cur_files in walk(orcid_index_dir):
        for cur_file in cur_files:
            if cur_file.endswith('.csv'):
                files_to_process.append(cur_dir + sep + cur_file)

    if not files_to_process:
        return

    if max_workers is None:
        max_workers = min(cpu_count() or 4, len(files_to_process))

    batch: dict[str, set[str]] = {}
    count = 0

    with create_progress() as progress:
        task = progress.add_task(
            "[green]Loading DOI-ORCID index files", total=len(files_to_process)
        )

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(_process_csv_file, path): path
                for path in files_to_process
            }

            for future in as_completed(futures):
                file_data = future.result()

                # Merge results into batch
                for doi, values in file_data.items():
                    if doi in batch:
                        batch[doi].update(values)
                    else:
                        batch[doi] = values
                    count += len(values)

                # Flush to Redis when batch is large enough
                if count >= batch_size:
                    orcid_index_redis.add_values_batch(batch)
                    batch = {}
                    count = 0

                progress.update(task, advance=1)

    if batch:
        orcid_index_redis.add_values_batch(batch)
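

# Usage sketch (hypothetical directory path and sizes, not part of the
# original module): walks a directory of CSV files, parses them in parallel
# worker processes, and flushes accumulated DOI -> ORCID pairs to Redis
# whenever the in-memory batch reaches roughly batch_size values.
def _example_load_orcid_index(index_dir: str) -> OrcidIndexRedis:
    index = OrcidIndexRedis(testing=True)  # fakeredis backend, no real Redis required
    if not index.has_data():
        load_orcid_index_to_redis(index_dir, index, batch_size=100_000, max_workers=2)
    return index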


class PublishersRedis:
    MEMBER_PREFIX = "member:"
    DOI_PREFIX_KEY = "prefix:"

    def __init__(self, testing: bool = False) -> None:
        if testing:
            self._r = fakeredis.FakeStrictRedis(decode_responses=True)
        else:
            self._redis = RedisDataSource("PUBLISHERS-INDEX")
            self._r = self._redis._r

    def get_by_member(self, member_id: str) -> dict[str, str | set[str]] | None:
        key = f"{self.MEMBER_PREFIX}{member_id}"
        data = self._r.get(key)
        if data:
            result = json.loads(str(data))
            result["prefixes"] = set(result["prefixes"])
            return result
        return None

    def get_by_prefix(self, prefix: str) -> dict[str, str | set[str]] | None:
        key = f"{self.DOI_PREFIX_KEY}{prefix}"
        member_id = self._r.get(key)
        if member_id:
            return self.get_by_member(str(member_id))
        return None

    def set_publisher(self, member_id: str, name: str, prefixes: set[str]) -> None:
        member_key = f"{self.MEMBER_PREFIX}{member_id}"
        data = {"name": name, "prefixes": list(prefixes)}
        self._r.set(member_key, json.dumps(data))
        for prefix in prefixes:
            prefix_key = f"{self.DOI_PREFIX_KEY}{prefix}"
            self._r.set(prefix_key, member_id)

    def set_publishers_batch(self, publishers: dict[str, dict[str, str | set[str]]]) -> None:
        pipe = self._r.pipeline()
        for member_id, data in publishers.items():
            member_key = f"{self.MEMBER_PREFIX}{member_id}"
            prefixes_list = list(data["prefixes"])
            pipe.set(member_key, json.dumps({"name": data["name"], "prefixes": prefixes_list}))
            for prefix in prefixes_list:
                prefix_key = f"{self.DOI_PREFIX_KEY}{prefix}"
                pipe.set(prefix_key, member_id)
        pipe.execute()

    def has_data(self) -> bool:
        return cast(int, self._r.dbsize()) > 0

    def clear(self) -> None:
        self._r.flushdb()
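

# Usage sketch (made-up member id, publisher name, and DOI prefixes; not part
# of the original module): each publisher is written twice, as a JSON blob
# under member:<id> and as one member-id pointer per DOI prefix under
# prefix:<prefix>, so records can be resolved from either direction.
def _example_publishers_roundtrip() -> None:
    publishers = PublishersRedis(testing=True)
    publishers.set_publisher("9999", "Example Publishing House", {"10.1000", "10.5555"})
    by_member = publishers.get_by_member("9999")
    assert by_member == {"name": "Example Publishing House",
                         "prefixes": {"10.1000", "10.5555"}}
    # A prefix lookup resolves the member id first, then reuses get_by_member().
    assert publishers.get_by_prefix("10.5555") == by_member
    assert publishers.get_by_prefix("10.9999") is None
    publishers.clear()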


def load_publishers_to_redis(
    publishers_filepath: str,
    publishers_redis: PublishersRedis,
    batch_size: int = 5000,
) -> None:
    if not exists(publishers_filepath):
        return

    batch: dict[str, dict[str, str | set[str]]] = {}
    count = 0

    with open(publishers_filepath, 'r', encoding='utf-8') as f:
        reader = DictReader(f)
        for row in reader:
            pub_id = row['id']
            if pub_id not in batch:
                batch[pub_id] = {'name': row['name'], 'prefixes': set()}
            prefixes = batch[pub_id]['prefixes']
            prefixes.add(row['prefix'])  # type: ignore[union-attr]
            count += 1

            if count >= batch_size:
                publishers_redis.set_publishers_batch(batch)
                batch = {}
                count = 0

    if batch:
        publishers_redis.set_publishers_batch(batch)
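

# Data-format and usage sketch (an assumption drawn from the columns the
# loader reads above; the file name and entries are invented): the publishers
# CSV is expected to carry 'id', 'name', and 'prefix' columns, one row per DOI
# prefix, so a publisher with several prefixes spans several rows.
def _example_load_publishers(tmp_dir: str) -> PublishersRedis:
    import csv
    import os

    csv_path = os.path.join(tmp_dir, "publishers_sample.csv")
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["id", "name", "prefix"])
        writer.writeheader()
        writer.writerow({"id": "9999", "name": "Example Publishing House", "prefix": "10.1000"})
        writer.writerow({"id": "9999", "name": "Example Publishing House", "prefix": "10.5555"})

    publishers = PublishersRedis(testing=True)
    load_publishers_to_redis(csv_path, publishers)
    # Both prefixes now resolve to the same publisher record.
    return publishers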