Coverage for oc_ds_converter / datasource / orcid_index.py: 81%

149 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-12 21:23 +0000

1# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it> 

2# 

3# SPDX-License-Identifier: ISC 

4 

5from __future__ import annotations 

6 

7import json 

8from collections import defaultdict 

9from concurrent.futures import ProcessPoolExecutor, as_completed 

10from csv import DictReader 

11from multiprocessing import get_context 

12from os import cpu_count, sep, walk 

13from os.path import exists 

14from typing import Protocol, cast 

15 

16import fakeredis 

17 

18from oc_ds_converter.datasource.redis import RedisDataSource 

19from oc_ds_converter.lib.console import create_progress 

20from oc_ds_converter.oc_idmanager import DOIManager 

21 

22 

class OrcidIndexInterface(Protocol):
    """Structural type for DOI -> ORCID lookups.

    Any object that provides both methods below with compatible signatures
    satisfies this protocol; no inheritance is required.
    """

    def get_value(self, id_string: str, /) -> set[str] | None:
        """Return the ORCID set for one identifier, or ``None`` when there is none."""
        ...

    def get_values_batch(self, ids: list[str], /) -> dict[str, set[str]]:
        """Return ORCID sets for several identifiers, keyed by identifier."""
        ...

26 

27 

class OrcidIndexRedis:
    """DOI -> ORCID index stored in Redis as one set of ORCIDs per DOI key.

    With ``testing=True`` an in-memory ``fakeredis`` server (string-decoding)
    is used instead of the ``DOI-ORCID-INDEX`` Redis data source.
    """

    def __init__(self, testing: bool = False) -> None:
        if not testing:
            self._redis = RedisDataSource("DOI-ORCID-INDEX")
            self._r = self._redis._r
            return
        self._r = fakeredis.FakeStrictRedis(decode_responses=True)

    def get_value(self, doi: str) -> set[str] | None:
        """Return the ORCIDs stored under ``doi``; ``None`` if the set is empty or missing."""
        members = cast(set[str], self._r.smembers(doi))
        return members or None

    def get_values_batch(self, dois: list[str]) -> dict[str, set[str]]:
        """Fetch the ORCID sets of many DOIs with a single pipeline execution.

        DOIs that have no stored ORCID are left out of the returned mapping.
        """
        if not dois:
            return {}
        pipe = self._r.pipeline()
        for key in dois:
            pipe.smembers(key)
        found: dict[str, set[str]] = {}
        for key, members in zip(dois, pipe.execute()):
            if members:
                found[key] = cast(set[str], members)
        return found

    def add_values_batch(self, data: dict[str, set[str]]) -> None:
        """Union each ORCID set into its DOI key through one pipeline.

        Empty sets are skipped: SADD requires at least one member.
        """
        pipe = self._r.pipeline()
        non_empty = ((doi, orcids) for doi, orcids in data.items() if orcids)
        for doi, orcids in non_empty:
            pipe.sadd(doi, *orcids)
        pipe.execute()

    def has_data(self) -> bool:
        """Tell whether the selected Redis database contains at least one key."""
        key_count = cast(int, self._r.dbsize())
        return key_count > 0

    def clear(self) -> None:
        """Remove every key from the selected Redis database (FLUSHDB)."""
        self._r.flushdb()

67 

68 

def _process_csv_file(csv_path: str) -> dict[str, set[str]]:
    """Read one DOI-ORCID CSV file and group its ``value`` column by normalised DOI.

    This function runs in a separate process for parallelization, so it is
    kept at module level.

    Args:
        csv_path: path of a CSV file with at least ``id`` and ``value`` columns.

    Returns:
        A plain dict mapping each DOI (normalised by ``DOIManager.normalise``
        with ``include_prefix=True``) to the set of ``value`` entries seen for
        it. Rows whose ``id`` does not normalise to a truthy DOI are skipped.
    """
    doi_manager = DOIManager()
    result: dict[str, set[str]] = defaultdict(set)

    # The csv module requires newline='' so that line breaks inside quoted
    # fields are handled by the reader rather than by text-mode translation.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        for row in DictReader(f):
            doi = doi_manager.normalise(row['id'], include_prefix=True)
            if doi:
                result[doi].add(row['value'])

    # Convert to a plain dict so the returned object has no default factory.
    return dict(result)

86 

87 

def load_orcid_index_to_redis(
    orcid_index_dir: str,
    orcid_index_redis: OrcidIndexRedis,
    batch_size: int = 1000000,
    max_workers: int | None = None,
) -> None:
    """Populate ``orcid_index_redis`` from every ``.csv`` file under a directory.

    Files are parsed in parallel by ``_process_csv_file`` on a process pool
    using the "forkserver" start method. Their DOI -> ORCID sets are merged
    in this process and written with ``add_values_batch`` whenever the summed
    size of the merged per-file sets reaches ``batch_size``. A final write
    handles whatever remains.

    Args:
        orcid_index_dir: root directory, walked recursively. The function
            returns immediately if it does not exist or holds no ``.csv`` file.
        orcid_index_redis: destination index.
        batch_size: accumulated ORCID count that triggers a flush to Redis.
        max_workers: pool size; defaults to ``min(cpu_count() or 4, n_files)``.
    """
    if not exists(orcid_index_dir):
        return

    csv_paths = [
        f"{root}{sep}{name}"
        for root, _, names in walk(orcid_index_dir)
        for name in names
        if name.endswith('.csv')
    ]
    if not csv_paths:
        return

    workers = (
        max_workers
        if max_workers is not None
        else min(cpu_count() or 4, len(csv_paths))
    )

    pending: dict[str, set[str]] = {}
    pending_values = 0

    with create_progress() as progress:
        task_id = progress.add_task(
            "[green]Loading DOI-ORCID index files", total=len(csv_paths)
        )

        with ProcessPoolExecutor(
            max_workers=workers, mp_context=get_context("forkserver")
        ) as pool:
            submitted = {
                pool.submit(_process_csv_file, path): path for path in csv_paths
            }

            for done in as_completed(submitted):
                # Merge this file's DOI -> ORCID sets into the pending batch.
                for doi, orcids in done.result().items():
                    already = pending.get(doi)
                    if already is None:
                        pending[doi] = orcids
                    else:
                        already.update(orcids)
                    pending_values += len(orcids)

                # Flush in chunks so the pending batch does not grow unbounded.
                if pending_values >= batch_size:
                    orcid_index_redis.add_values_batch(pending)
                    pending = {}
                    pending_values = 0

                progress.update(task_id, advance=1)

    if pending:
        orcid_index_redis.add_values_batch(pending)

146 

147 

class PublishersRedis:
    """Redis-backed publisher lookup by member id or by DOI prefix.

    Storage layout (as written by ``set_publishers_batch``):

    * ``member:<member_id>`` -> JSON ``{"name": ..., "prefixes": [...]}``
    * ``prefix:<doi_prefix>`` -> ``<member_id>``

    With ``testing=True`` an in-memory ``fakeredis`` server (string-decoding)
    is used instead of the ``PUBLISHERS-INDEX`` Redis data source.
    """

    MEMBER_PREFIX = "member:"
    DOI_PREFIX_KEY = "prefix:"

    def __init__(self, testing: bool = False) -> None:
        if testing:
            self._r = fakeredis.FakeStrictRedis(decode_responses=True)
        else:
            self._redis = RedisDataSource("PUBLISHERS-INDEX")
            self._r = self._redis._r

    def get_by_member(self, member_id: str) -> dict[str, str | set[str]] | None:
        """Return ``{"name": str, "prefixes": set[str]}`` for a member, or ``None``."""
        key = f"{self.MEMBER_PREFIX}{member_id}"
        data = self._r.get(key)
        if not data:
            return None
        result = json.loads(str(data))
        # Prefixes are stored as a JSON list; expose them as a set.
        result["prefixes"] = set(result["prefixes"])
        return result

    def get_by_prefix(self, prefix: str) -> dict[str, str | set[str]] | None:
        """Resolve a DOI prefix to its member id, then return that member's record."""
        key = f"{self.DOI_PREFIX_KEY}{prefix}"
        member_id = self._r.get(key)
        if not member_id:
            return None
        return self.get_by_member(str(member_id))

    def set_publisher(self, member_id: str, name: str, prefixes: set[str]) -> None:
        """Store one publisher record and map each of its prefixes to ``member_id``.

        Delegates to ``set_publishers_batch``. All keys are therefore sent in a
        single pipeline execution instead of 1 + len(prefixes) separate
        round trips, and the serialization format is shared with the batch path.
        """
        self.set_publishers_batch({member_id: {"name": name, "prefixes": prefixes}})

    def set_publishers_batch(self, publishers: dict[str, dict[str, str | set[str]]]) -> None:
        """Store many publisher records and their prefix mappings in one pipeline.

        Each ``member:<id>`` value is overwritten, not merged. Each
        ``publishers`` value must have a ``"name"`` string and an iterable
        ``"prefixes"``.
        """
        pipe = self._r.pipeline()
        for member_id, data in publishers.items():
            member_key = f"{self.MEMBER_PREFIX}{member_id}"
            prefixes_list = list(data["prefixes"])
            pipe.set(member_key, json.dumps({"name": data["name"], "prefixes": prefixes_list}))
            for prefix in prefixes_list:
                prefix_key = f"{self.DOI_PREFIX_KEY}{prefix}"
                pipe.set(prefix_key, member_id)
        pipe.execute()

    def has_data(self) -> bool:
        """Tell whether the selected Redis database contains at least one key."""
        return cast(int, self._r.dbsize()) > 0

    def clear(self) -> None:
        """Remove every key from the selected Redis database (FLUSHDB)."""
        self._r.flushdb()

199 

200 

def load_publishers_to_redis(
    publishers_filepath: str,
    publishers_redis: PublishersRedis,
    batch_size: int = 5000,
) -> None:
    """Load a publishers CSV into Redis, flushing every ``batch_size`` rows.

    The CSV must provide ``id``, ``name`` and ``prefix`` columns. Each row
    adds one DOI prefix to the publisher identified by ``id``. Rows are
    grouped per publisher and written with
    ``PublishersRedis.set_publishers_batch``, plus a final flush for the
    remainder.

    ``set_publishers_batch`` overwrites the ``member:<id>`` record. A publisher
    whose rows straddle a flush boundary would therefore lose the prefixes
    written by the earlier flush. To avoid this, when a row arrives for an id
    that was already flushed, the stored record (name and prefixes) is loaded
    and used to seed the new batch entry. The first name seen for an id is
    kept throughout.

    Args:
        publishers_filepath: path of the CSV file. Nothing happens if it does
            not exist.
        publishers_redis: destination store.
        batch_size: number of CSV rows accumulated before each flush.
    """
    if not exists(publishers_filepath):
        return

    batch: dict[str, dict[str, str | set[str]]] = {}
    flushed_ids: set[str] = set()
    count = 0

    # The csv module requires newline='' so quoted fields containing line
    # breaks are parsed correctly.
    with open(publishers_filepath, 'r', encoding='utf-8', newline='') as f:
        for row in DictReader(f):
            pub_id = row['id']
            entry = batch.get(pub_id)
            if entry is None:
                entry = {'name': row['name'], 'prefixes': set()}
                if pub_id in flushed_ids:
                    # An earlier flush already wrote this member; carry its
                    # data over so the next overwrite does not drop it.
                    stored = publishers_redis.get_by_member(pub_id)
                    if stored is not None:
                        entry['name'] = stored['name']
                        entry['prefixes'] = set(cast(set[str], stored['prefixes']))
                batch[pub_id] = entry
            cast(set[str], entry['prefixes']).add(row['prefix'])
            count += 1

            if count >= batch_size:
                publishers_redis.set_publishers_batch(batch)
                flushed_ids.update(batch)
                batch = {}
                count = 0

    if batch:
        publishers_redis.set_publishers_batch(batch)